| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__metrics | ||
| 2 | !! Module containing functions to compute the accuracy of a model | ||
| 3 | !! | ||
| 4 | !! This module contains a derived type for storing and handling metric data | ||
| 5 | use coreutils, only: real32, stop_program | ||
| 6 | implicit none | ||
| 7 | |||
| 8 | |||
| 9 | private | ||
| 10 | |||
| 11 | public :: metric_dict_type | ||
| 12 | public :: metric_dict_alloc | ||
| 13 | |||
| 14 | |||
| 15 | type :: metric_dict_type | ||
| 16 | !! Type for storing and handling metric data | ||
| 17 | character(10) :: key | ||
| 18 | !! Key for the metric | ||
| 19 | real(real32) :: val | ||
| 20 | !! Value of the metric | ||
| 21 | logical :: active | ||
| 22 | !! Flag to indicate if the metric is active | ||
| 23 | real(real32) :: threshold | ||
| 24 | !! Threshold for the metric | ||
| 25 | integer :: window_width | ||
| 26 | !! Window width for checking convergence | ||
| 27 | integer :: num_entries | ||
| 28 | !! Number of entries in the history | ||
| 29 | real(real32), allocatable, dimension(:) :: history | ||
| 30 | !! History of the metric | ||
| 31 | contains | ||
| 32 | procedure :: check => metric_dict_check | ||
| 33 | !! Check if the metric has converged | ||
| 34 | procedure :: add_t_t => metric_dict_add | ||
| 35 | !! Add two metric_dict_type together | ||
| 36 | procedure :: append => append_value | ||
| 37 | !! Append a value to the history of the metric | ||
| 38 | generic :: operator(+) => add_t_t | ||
| 39 | !! Overload the addition operator | ||
| 40 | end type metric_dict_type | ||
| 41 | |||
| 42 | |||
| 43 | |||
| 44 | contains | ||
| 45 | |||
| 46 | !############################################################################### | ||
| 47 | 1 | elemental function metric_dict_add(a, b) result(output) | |
| 48 | !! Operation to add two metric_dict_type together | ||
| 49 | implicit none | ||
| 50 | |||
| 51 | ! Arguments | ||
| 52 | class(metric_dict_type), intent(in) :: a,b | ||
| 53 | !! Instances of metric data | ||
| 54 | type(metric_dict_type) :: output | ||
| 55 | !! Sum of the metric data | ||
| 56 | |||
| 57 | 1 | output%key = a%key | |
| 58 | 1 | output%val = a%val + b%val | |
| 59 | 1 | output%threshold = a%threshold | |
| 60 | 1 | output%active = a%active | |
| 61 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✓ Branch 23 taken 1 times.
|
6 | if(allocated(a%history)) output%history = a%history |
| 62 | 1 | output%num_entries = a%num_entries | |
| 63 | |||
| 64 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
2 | end function metric_dict_add |
| 65 | !############################################################################### | ||
| 66 | |||
| 67 | |||
| 68 | !############################################################################### | ||
| 69 |
16/24✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✓ Branch 9 taken 2 times.
✓ Branch 10 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 4 times.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 4 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
|
10 | subroutine metric_dict_alloc(input, source, length) |
| 70 | !! Allocate memory for a metric_dict_type | ||
| 71 | implicit none | ||
| 72 | |||
| 73 | ! Arguments | ||
| 74 | type(metric_dict_type), dimension(:), intent(out) :: input | ||
| 75 | !! Instance of metric data | ||
| 76 | type(metric_dict_type), dimension(:), optional, intent(in) :: source | ||
| 77 | !! Source of the metric data to copy | ||
| 78 | integer, optional, intent(in) :: length | ||
| 79 | !! Length of the metric data | ||
| 80 | |||
| 81 | ! Local variables | ||
| 82 | integer :: i | ||
| 83 | !! Loop index | ||
| 84 | |||
| 85 | |||
| 86 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | if(present(length))then |
| 87 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 9 taken 2 times.
✓ Branch 10 taken 1 times.
|
3 | do i=1,size(input,dim=1) |
| 88 |
9/18✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
|
3 | allocate(input(i)%history(length)) |
| 89 | end do | ||
| 90 | else | ||
| 91 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(source))then |
| 92 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 9 taken 2 times.
✓ Branch 10 taken 1 times.
|
3 | do i=1, size(input,dim=1) |
| 93 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | input(i)%key = source(i)%key |
| 94 |
11/22✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 2 times.
|
2 | allocate(input(i)%history(size(source(i)%history,dim=1))) |
| 95 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
3 | input(i)%threshold = source(i)%threshold |
| 96 | end do | ||
| 97 | else | ||
| 98 | call stop_program( & | ||
| 99 | "metric_dict_alloc requires either a source or length" & | ||
| 100 | ✗ | ) | |
| 101 | end if | ||
| 102 | end if | ||
| 103 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✓ Branch 9 taken 4 times.
✓ Branch 10 taken 2 times.
|
6 | input%num_entries = 0 |
| 104 | |||
| 105 | 2 | end subroutine metric_dict_alloc | |
| 106 | !############################################################################### | ||
| 107 | |||
| 108 | |||
| 109 | !############################################################################### | ||
| 110 | 2232 | subroutine append_value(this, value) | |
| 111 | !! Append a value to the history of the metric | ||
| 112 | implicit none | ||
| 113 | |||
| 114 | ! Arguments | ||
| 115 | class(metric_dict_type), intent(inout) :: this | ||
| 116 | !! Instance of metric data | ||
| 117 | real(real32), intent(in) :: value | ||
| 118 | !! Value to append | ||
| 119 | |||
| 120 | ! Local variables | ||
| 121 | integer :: new_size | ||
| 122 | |||
| 123 | 2232 | this%val = value | |
| 124 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2224 times.
|
2232 | if(.not.allocated(this%history))then |
| 125 |
13/24✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 8 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 8 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 8 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 8 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 8 times.
✓ Branch 29 taken 3010 times.
✓ Branch 30 taken 8 times.
|
3018 | allocate(this%history(this%window_width), source = -huge(1._real32)) |
| 126 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
|
8 | this%history(this%window_width) = value |
| 127 | 8 | this%num_entries = 0 | |
| 128 |
2/2✓ Branch 0 taken 1022 times.
✓ Branch 1 taken 1202 times.
|
2224 | elseif(this%num_entries .lt. this%window_width)then |
| 129 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1022 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1022 times.
|
1022 | this%history(this%num_entries) = value |
| 130 | else | ||
| 131 |
15/24✗ Branch 0 not taken.
✓ Branch 1 taken 1202 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1202 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1202 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1202 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1202 times.
✓ Branch 15 taken 961600 times.
✓ Branch 16 taken 1202 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1202 times.
✓ Branch 19 taken 962802 times.
✓ Branch 20 taken 1202 times.
✓ Branch 21 taken 1202 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1202 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1202 times.
✓ Branch 27 taken 962802 times.
✓ Branch 28 taken 1202 times.
|
2888406 | this%history = [ this%history, value ] |
| 132 | end if | ||
| 133 | 2232 | this%num_entries = this%num_entries + 1 | |
| 134 | |||
| 135 | 2232 | end subroutine append_value | |
| 136 | !############################################################################### | ||
| 137 | |||
| 138 | |||
| 139 | !############################################################################### | ||
| 140 | 1621 | subroutine metric_dict_check(this,plateau_threshold,converged) | |
| 141 | !! Check if the metric has converged | ||
| 142 | implicit none | ||
| 143 | |||
| 144 | ! Arguments | ||
| 145 | class(metric_dict_type), intent(inout) :: this | ||
| 146 | !! Instance of metric data | ||
| 147 | real(real32), intent(in) :: plateau_threshold | ||
| 148 | !! Threshold for plateau | ||
| 149 | integer, intent(out) :: converged | ||
| 150 | !! Boolean whether the metric has converged | ||
| 151 | |||
| 152 | ! Local variables | ||
| 153 | integer :: window_width | ||
| 154 | !! Width of the convergence check window | ||
| 155 | integer :: window_ubound, window_lbound | ||
| 156 | !! Upper and lower bounds of the window | ||
| 157 | |||
| 158 | 1621 | converged = 0 | |
| 159 | 1621 | window_width = min(this%window_width, this%num_entries) | |
| 160 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1621 times.
|
1621 | if(window_width .le. 0)then |
| 161 | ✗ | call stop_program("Window width is zero or negative") | |
| 162 | ✗ | return | |
| 163 | end if | ||
| 164 | 1621 | window_ubound = this%num_entries | |
| 165 | 1621 | window_lbound = window_ubound - window_width + 1 | |
| 166 |
2/2✓ Branch 0 taken 512 times.
✓ Branch 1 taken 1109 times.
|
1621 | if(this%active)then |
| 167 | if( & | ||
| 168 | ( & | ||
| 169 | trim(this%key).eq."loss".and.& | ||
| 170 | abs( & | ||
| 171 | 2048 | sum( this%history(window_lbound:window_ubound) ) & | |
| 172 | ) / window_width.lt.& | ||
| 173 | this%threshold & | ||
| 174 |
12/18✗ Branch 1 not taken.
✓ Branch 2 taken 512 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 512 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 512 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 512 times.
✓ Branch 13 taken 125290 times.
✓ Branch 14 taken 512 times.
✓ Branch 16 taken 125290 times.
✓ Branch 17 taken 512 times.
✓ Branch 18 taken 512 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 512 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 510 times.
|
251604 | ) .or. & |
| 175 | ( & | ||
| 176 | trim(this%key).eq."accuracy".and.& | ||
| 177 | abs( & | ||
| 178 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 512 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 512 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 512 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 512 times.
|
512 | sum( 1._real32 - this%history(window_lbound:window_ubound) ) & |
| 179 | ) / window_width.lt.& | ||
| 180 | this%threshold & | ||
| 181 | ) & | ||
| 182 | )then | ||
| 183 | write(6,*) & | ||
| 184 |
1/2✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
|
2 | "Convergence achieved, "//trim(this%key)//" threshold reached" |
| 185 | 2 | write(6,*) "Exiting training loop" | |
| 186 | 2 | converged = 1 | |
| 187 | elseif( & | ||
| 188 |
10/14✗ Branch 0 not taken.
✓ Branch 1 taken 510 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 510 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 510 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 510 times.
✓ Branch 12 taken 514 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 509 times.
✓ Branch 15 taken 5 times.
✓ Branch 16 taken 1 times.
✓ Branch 17 taken 509 times.
|
515 | all( abs(this%history(window_lbound:window_ubound) - this%val) .lt. & |
| 189 | plateau_threshold & | ||
| 190 | ) & | ||
| 191 | )then | ||
| 192 | write(0,'("ERROR: ",A," has remained constant for ",I0," runs")') & | ||
| 193 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | trim(this%key), size(this%history,dim=1) |
| 194 | 1 | write(0,*) this%history | |
| 195 | 1 | write(0,*) "Exiting..." | |
| 196 | 1 | converged = -1 | |
| 197 | end if | ||
| 198 | end if | ||
| 199 | |||
| 200 | end subroutine metric_dict_check | ||
| 201 | !############################################################################### | ||
| 202 | |||
| 203 | ✗ | end module athena__metrics | |
| 204 |