| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__metrics | ||
| 2 | !! Module containing functions to compute the accuracy of a model | ||
| 3 | !! | ||
| 4 | !! This module contains a derived type for storing and handling metric data | ||
| 5 | use coreutils, only: real32, stop_program | ||
| 6 | implicit none | ||
| 7 | |||
| 8 | |||
| 9 | private | ||
| 10 | |||
| 11 | public :: metric_dict_type | ||
| 12 | public :: metric_dict_alloc | ||
| 13 | |||
| 14 | |||
| 15 | type :: metric_dict_type | ||
| 16 | !! Type for storing and handling metric data | ||
| 17 | character(10) :: key | ||
| 18 | !! Key for the metric | ||
| 19 | real(real32) :: val | ||
| 20 | !! Value of the metric | ||
| 21 | logical :: active | ||
| 22 | !! Flag to indicate if the metric is active | ||
| 23 | real(real32) :: threshold | ||
| 24 | !! Threshold for the metric | ||
| 25 | integer :: window_width | ||
| 26 | !! Window width for checking convergence | ||
| 27 | integer :: num_entries | ||
| 28 | !! Number of entries in the history | ||
| 29 | real(real32), allocatable, dimension(:) :: history | ||
| 30 | !! History of the metric | ||
| 31 | contains | ||
| 32 | procedure :: check => metric_dict_check | ||
| 33 | !! Check if the metric has converged | ||
| 34 | procedure :: add_t_t => metric_dict_add | ||
| 35 | !! Add two metric_dict_type together | ||
| 36 | procedure :: append => append_value | ||
| 37 | !! Append a value to the history of the metric | ||
| 38 | generic :: operator(+) => add_t_t | ||
| 39 | !! Overload the addition operator | ||
| 40 | end type metric_dict_type | ||
| 41 | |||
| 42 | |||
| 43 | |||
| 44 | contains | ||
| 45 | |||
| 46 | !############################################################################### | ||
| 47 | − | elemental function metric_dict_add(a, b) result(output) | |
| 48 | !! Operation to add two metric_dict_type together | ||
| 49 | implicit none | ||
| 50 | |||
| 51 | ! Arguments | ||
| 52 | class(metric_dict_type), intent(in) :: a,b | ||
| 53 | !! Instances of metric data | ||
| 54 | type(metric_dict_type) :: output | ||
| 55 | !! Sum of the metric data | ||
| 56 | |||
| 57 | − | output%key = a%key | |
| 58 | − | output%val = a%val + b%val | |
| 59 | − | output%threshold = a%threshold | |
| 60 | − | output%active = a%active | |
| 61 | − | if(allocated(a%history)) output%history = a%history | |
| 62 | − | output%num_entries = a%num_entries | |
| 63 | |||
| 64 | − | end function metric_dict_add | |
| 65 | !############################################################################### | ||
| 66 | |||
| 67 | |||
| 68 | !############################################################################### | ||
| 69 | − | subroutine metric_dict_alloc(input, source, length) | |
| 70 | !! Allocate memory for a metric_dict_type | ||
| 71 | implicit none | ||
| 72 | |||
| 73 | ! Arguments | ||
| 74 | type(metric_dict_type), dimension(:), intent(out) :: input | ||
| 75 | !! Instance of metric data | ||
| 76 | type(metric_dict_type), dimension(:), optional, intent(in) :: source | ||
| 77 | !! Source of the metric data to copy | ||
| 78 | integer, optional, intent(in) :: length | ||
| 79 | !! Length of the metric data | ||
| 80 | |||
| 81 | ! Local variables | ||
| 82 | integer :: i | ||
| 83 | !! Loop index | ||
| 84 | |||
| 85 | |||
| 86 | − | if(present(length))then | |
| 87 | − | do i=1,size(input,dim=1) | |
| 88 | − | allocate(input(i)%history(length)) | |
| 89 | end do | ||
| 90 | else | ||
| 91 | − | if(present(source))then | |
| 92 | − | do i=1, size(input,dim=1) | |
| 93 | − | input(i)%key = source(i)%key | |
| 94 | − | allocate(input(i)%history(size(source(i)%history,dim=1))) | |
| 95 | − | input(i)%threshold = source(i)%threshold | |
| 96 | end do | ||
| 97 | else | ||
| 98 | call stop_program( & | ||
| 99 | "metric_dict_alloc requires either a source or length" & | ||
| 100 | − | ) | |
| 101 | end if | ||
| 102 | end if | ||
| 103 | − | input%num_entries = 0 | |
| 104 | |||
| 105 | − | end subroutine metric_dict_alloc | |
| 106 | !############################################################################### | ||
| 107 | |||
| 108 | |||
| 109 | !############################################################################### | ||
| 110 | − | subroutine append_value(this, value) | |
| 111 | !! Append a value to the history of the metric | ||
| 112 | implicit none | ||
| 113 | |||
| 114 | ! Arguments | ||
| 115 | class(metric_dict_type), intent(inout) :: this | ||
| 116 | !! Instance of metric data | ||
| 117 | real(real32), intent(in) :: value | ||
| 118 | !! Value to append | ||
| 119 | |||
| 120 | ! Local variables | ||
| 121 | integer :: new_size | ||
| 122 | |||
| 123 | − | this%val = value | |
| 124 | − | if(.not.allocated(this%history)) then | |
| 125 | − | allocate(this%history(this%window_width), source = -huge(1._real32)) | |
| 126 | − | this%history(this%window_width) = value | |
| 127 | − | this%num_entries = 0 | |
| 128 | − | elseif(this%num_entries .lt. this%window_width) then | |
| 129 | − | this%history(this%num_entries) = value | |
| 130 | else | ||
| 131 | − | this%history = [ this%history, value ] | |
| 132 | end if | ||
| 133 | − | this%num_entries = this%num_entries + 1 | |
| 134 | |||
| 135 | − | end subroutine append_value | |
| 136 | !############################################################################### | ||
| 137 | |||
| 138 | |||
| 139 | !############################################################################### | ||
| 140 | − | subroutine metric_dict_check(this,plateau_threshold,converged) | |
| 141 | !! Check if the metric has converged | ||
| 142 | implicit none | ||
| 143 | |||
| 144 | ! Arguments | ||
| 145 | class(metric_dict_type), intent(inout) :: this | ||
| 146 | !! Instance of metric data | ||
| 147 | real(real32), intent(in) :: plateau_threshold | ||
| 148 | !! Threshold for plateau | ||
| 149 | integer, intent(out) :: converged | ||
| 150 | !! Boolean whether the metric has converged | ||
| 151 | |||
| 152 | ! Local variables | ||
| 153 | integer :: window_width | ||
| 154 | !! Width of the convergence check window | ||
| 155 | integer :: window_ubound, window_lbound | ||
| 156 | !! Upper and lower bounds of the window | ||
| 157 | |||
| 158 | − | converged = 0 | |
| 159 | − | window_width = min(this%window_width, this%num_entries) | |
| 160 | − | if(window_width .le. 0) then | |
| 161 | − | call stop_program("Window width is zero or negative") | |
| 162 | − | return | |
| 163 | end if | ||
| 164 | − | window_ubound = this%num_entries | |
| 165 | − | window_lbound = window_ubound - window_width + 1 | |
| 166 | − | if(this%active)then | |
| 167 | if( & | ||
| 168 | ( & | ||
| 169 | trim(this%key).eq."loss".and.& | ||
| 170 | abs( & | ||
| 171 | − | sum( this%history(window_lbound:window_ubound) ) & | |
| 172 | ) / window_width.lt.& | ||
| 173 | this%threshold & | ||
| 174 | − | ) .or. & | |
| 175 | ( & | ||
| 176 | trim(this%key).eq."accuracy".and.& | ||
| 177 | abs( & | ||
| 178 | − | sum( 1._real32 - this%history(window_lbound:window_ubound) ) & | |
| 179 | ) / window_width.lt.& | ||
| 180 | this%threshold & | ||
| 181 | ) & | ||
| 182 | )then | ||
| 183 | write(6,*) & | ||
| 184 | − | "Convergence achieved, "//trim(this%key)//" threshold reached" | |
| 185 | − | write(6,*) "Exiting training loop" | |
| 186 | − | converged = 1 | |
| 187 | elseif( & | ||
| 188 | − | all( abs(this%history(window_lbound:window_ubound) - this%val) .lt. & | |
| 189 | plateau_threshold & | ||
| 190 | ) & | ||
| 191 | )then | ||
| 192 | write(0,'("ERROR: ",A," has remained constant for ",I0," runs")') & | ||
| 193 | − | trim(this%key), size(this%history,dim=1) | |
| 194 | − | write(0,*) this%history | |
| 195 | − | write(0,*) "Exiting..." | |
| 196 | − | converged = -1 | |
| 197 | end if | ||
| 198 | end if | ||
| 199 | |||
| 200 | end subroutine metric_dict_check | ||
| 201 | !############################################################################### | ||
| 202 | |||
| 203 | − | end module athena__metrics | |
| 204 |