| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__clipper | ||
| 2 | !! Module containing functions to clip gradients | ||
| 3 | !! | ||
| 4 | !! This module implements clipping methods for layer gradients: | ||
| 5 | !! element-wise clamping to a (min, max) range and rescaling to a | ||
| 6 | !! maximum combined L2-norm. | ||
| 7 | use coreutils, only: real32 | ||
| 8 | use, intrinsic :: iso_fortran_env, only: error_unit | ||
| 9 | implicit none | ||
| 10 | |||
| 11 | |||
| 12 | private | ||
| 13 | |||
| 14 | public :: clip_type | ||
| 15 | |||
| 16 | |||
| 17 | type clip_type | ||
| 18 | !! Type for clipping gradients | ||
| 19 | logical :: l_min_max = .false. | ||
| 20 | !! Boolean whether min/max values are set | ||
| 21 | logical :: l_norm = .false. | ||
| 22 | !! Boolean whether a norm is set | ||
| 23 | real(real32) :: min =-huge(1._real32) | ||
| 24 | !! Minimum value for clipping | ||
| 25 | real(real32) :: max = huge(1._real32) | ||
| 26 | !! Maximum value for clipping | ||
| 27 | real(real32) :: norm = huge(1._real32) | ||
| 28 | !! Maximum L2-norm for clipping (should be positive) | ||
| 29 | contains | ||
| 30 | procedure, pass(this) :: read => read_clip | ||
| 31 | !! Read clipping information | ||
| 32 | procedure, pass(this) :: set => set_clip | ||
| 33 | !! Set clipping information | ||
| 34 | procedure, pass(this) :: apply => apply_clip | ||
| 35 | !! Apply clipping to gradients | ||
| 36 | end type clip_type | ||
| 37 | |||
| 38 | interface clip_type | ||
| 39 | !! Interface for the clip type | ||
| 40 | module function clip_setup( & | ||
| 41 | clip_min, clip_max, clip_norm) result(clip) | ||
| 42 | !! Set up the clip dictionary | ||
| 43 | real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm | ||
| 44 | !! Minimum, maximum, and norm values for clipping | ||
| 45 | type(clip_type) :: clip | ||
| 46 | !! Clip dictionary | ||
| 47 | end function clip_setup | ||
| 48 | end interface clip_type | ||
| 49 | |||
| 50 | |||
| 51 | |||
| 52 | contains | ||
| 53 | |||
| 54 | !############################################################################### | ||
| 55 | − | module function clip_setup( & | |
| 56 | − | clip_min, clip_max, clip_norm) result(clip) | |
| 57 | !! Set up the clip dictionary | ||
| 58 | !! | ||
| 59 | !! Each present argument enables the matching clipping mode; absent | ||
| 60 | !! arguments leave the type's default (disabled) settings in place. | ||
| 61 | implicit none | ||
| 62 | |||
| 63 | ! Arguments | ||
| 64 | real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm | ||
| 65 | !! Minimum, maximum, and norm values for clipping | ||
| 66 | type(clip_type) :: clip | ||
| 67 | !! Instance of the clip type | ||
| 68 | |||
| 69 | |||
| 70 | !--------------------------------------------------------------------------- | ||
| 71 | ! Set up clipping limits | ||
| 72 | ! Delegate to set_clip so both entry points share one implementation; | ||
| 73 | ! absent optional arguments are forwarded as absent. | ||
| 74 | !--------------------------------------------------------------------------- | ||
| 75 | − | call clip%set(clip_min=clip_min, clip_max=clip_max, clip_norm=clip_norm) | |
| 76 | |||
| 77 | − | end function clip_setup | |
| 78 | !############################################################################### | ||
| 79 | |||
| 80 | |||
| 81 | !############################################################################### | ||
| 82 | − | subroutine read_clip(this, min_str, max_str, norm_str) | |
| 83 | !! Read clipping information from strings | ||
| 84 | !! | ||
| 85 | !! A blank min/max string resets that limit to -huge/huge; a blank norm | ||
| 86 | !! string leaves the stored norm and its flag unchanged. Execution stops | ||
| 87 | !! with a message if a non-blank string cannot be read as a real number. | ||
| 88 | implicit none | ||
| 89 | |||
| 90 | ! Arguments | ||
| 91 | class(clip_type), intent(inout) :: this | ||
| 92 | !! Instance of the clip type | ||
| 93 | character(*), intent(in) :: min_str, max_str, norm_str | ||
| 94 | !! Strings for min, max, and norm values | ||
| 95 | |||
| 96 | − | if(len_trim(min_str).gt.0)then | |
| 97 | − | call read_real_value(min_str, "clip_min", this%min) | |
| 98 | else | ||
| 99 | − | this%min = -huge(1._real32) | |
| 100 | end if | ||
| 101 | − | if(len_trim(max_str).gt.0)then | |
| 102 | − | call read_real_value(max_str, "clip_max", this%max) | |
| 103 | else | ||
| 104 | − | this%max = huge(1._real32) | |
| 105 | end if | ||
| 106 | |||
| 107 | − | if(len_trim(min_str).gt.0.or.len_trim(max_str).gt.0)then | |
| 108 | − | this%l_min_max = .true. | |
| 109 | end if | ||
| 110 | − | if(len_trim(norm_str).gt.0)then | |
| 111 | − | call read_real_value(norm_str, "clip_norm", this%norm) | |
| 112 | − | this%l_norm = .true. | |
| 113 | end if | ||
| 114 | |||
| 115 | − | end subroutine read_clip | |
| 116 | !############################################################################### | ||
| 117 | |||
| 118 | |||
| 119 | !############################################################################### | ||
| 120 | − | subroutine read_real_value(str, label, val) | |
| 121 | !! Read a single real value from a string using list-directed input | ||
| 122 | !! | ||
| 123 | !! On a read error, writes a message naming the option and the offending | ||
| 124 | !! string to error_unit and stops execution. | ||
| 125 | implicit none | ||
| 126 | |||
| 127 | ! Arguments | ||
| 128 | character(*), intent(in) :: str | ||
| 129 | !! String to read from | ||
| 130 | character(*), intent(in) :: label | ||
| 131 | !! Name of the option being read (used in the error message) | ||
| 132 | real(real32), intent(inout) :: val | ||
| 133 | !! Value read; left unchanged if the string holds a null value | ||
| 134 | |||
| 135 | ! Local variables | ||
| 136 | integer :: stat | ||
| 137 | !! I/O status of the internal read | ||
| 138 | character(256) :: msg | ||
| 139 | !! I/O error message | ||
| 140 | |||
| 141 | − | read(str,*,iostat=stat,iomsg=msg) val | |
| 142 | − | if(stat.ne.0)then | |
| 143 | − | write(error_unit,'(A)') "read_clip: could not read "//label// & | |
| 144 | − | " from '"//trim(str)//"': "//trim(msg) | |
| 145 | − | error stop 1 | |
| 146 | end if | ||
| 147 | |||
| 148 | − | end subroutine read_real_value | |
| 149 | !############################################################################### | ||
| 150 | |||
| 151 | |||
| 152 | !############################################################################### | ||
| 153 | − | subroutine set_clip(this, clip_dict, clip_min, clip_max, clip_norm) | |
| 154 | !! Set clipping information | ||
| 155 | !! | ||
| 156 | !! If clip_dict is present its settings are copied and any scalar | ||
| 157 | !! arguments are ignored with a warning on error_unit. Otherwise each | ||
| 158 | !! present scalar argument enables and sets its clipping option; absent | ||
| 159 | !! ones leave the current settings unchanged. | ||
| 160 | implicit none | ||
| 161 | |||
| 162 | ! Arguments | ||
| 163 | class(clip_type), intent(inout) :: this | ||
| 164 | !! Instance of the clip type | ||
| 165 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 166 | !! Clip dictionary | ||
| 167 | real(real32), optional, intent(in) :: clip_min, clip_max, clip_norm | ||
| 168 | !! Minimum, maximum, and norm values for clipping | ||
| 169 | |||
| 170 | |||
| 171 | !--------------------------------------------------------------------------- | ||
| 172 | ! Set up clipping limits | ||
| 173 | !--------------------------------------------------------------------------- | ||
| 174 | − | if(present(clip_dict))then | |
| 175 | − | this%l_min_max = clip_dict%l_min_max | |
| 176 | − | this%l_norm = clip_dict%l_norm | |
| 177 | − | this%min = clip_dict%min | |
| 178 | − | this%max = clip_dict%max | |
| 179 | − | this%norm = clip_dict%norm | |
| 180 | − | if(present(clip_min).or.present(clip_max).or.present(clip_norm))then | |
| 181 | ! Warnings go to stderr so they do not mix with regular output | ||
| 182 | − | write(error_unit,*) "Multiple clip options provided" | |
| 183 | − | write(error_unit,*) "Ignoring all except clip_dict" | |
| 184 | end if | ||
| 185 | else | ||
| 186 | − | if(present(clip_min))then | |
| 187 | − | this%l_min_max = .true. | |
| 188 | − | this%min = clip_min | |
| 189 | end if | ||
| 190 | − | if(present(clip_max))then | |
| 191 | − | this%l_min_max = .true. | |
| 192 | − | this%max = clip_max | |
| 193 | end if | ||
| 194 | − | if(present(clip_norm))then | |
| 195 | − | this%l_norm = .true. | |
| 196 | − | this%norm = clip_norm | |
| 197 | end if | ||
| 198 | end if | ||
| 199 | |||
| 200 | − | end subroutine set_clip | |
| 201 | !############################################################################### | ||
| 202 | |||
| 203 | |||
| 204 | !############################################################################### | ||
| 205 | − | pure subroutine apply_clip(this, length, gradient, bias) | |
| 206 | !! Apply clipping to a gradient (and optional bias) in place | ||
| 207 | !! | ||
| 208 | !! With min/max clipping enabled, every element is clamped to | ||
| 209 | !! [this%min, this%max]. With norm clipping enabled, gradient and bias | ||
| 210 | !! are then rescaled together so their combined L2-norm is at most | ||
| 211 | !! this%norm. | ||
| 212 | implicit none | ||
| 213 | |||
| 214 | ! Arguments | ||
| 215 | class(clip_type), intent(in) :: this | ||
| 216 | !! Instance of the clip type | ||
| 217 | integer, intent(in) :: length | ||
| 218 | !! Length of the gradient | ||
| 219 | real(real32), dimension(length), intent(inout) :: gradient | ||
| 220 | !! Gradient to be clipped | ||
| 221 | real(real32), dimension(:), optional, intent(inout) :: bias | ||
| 222 | !! Bias to be clipped | ||
| 223 | |||
| 224 | ! Local variables | ||
| 225 | real(real32) :: total_norm | ||
| 226 | !! Combined L2-norm of the gradient and bias | ||
| 227 | real(real32) :: factor | ||
| 228 | !! Scaling factor applied when the norm limit is exceeded | ||
| 229 | |||
| 230 | ! Clip values to within limits of (min,max) | ||
| 231 | ! An absent bias is simply skipped, so it can neither be clamped into a | ||
| 232 | ! non-zero value nor contribute to the norm below. | ||
| 233 | − | if(this%l_min_max)then | |
| 234 | − | gradient = max(this%min,min(this%max,gradient)) | |
| 235 | − | if(present(bias)) bias = max(this%min,min(this%max,bias)) | |
| 236 | end if | ||
| 237 | |||
| 238 | ! Clip values to a maximum L2-norm | ||
| 239 | − | if(this%l_norm)then | |
| 240 | ! Combined norm is sqrt(sum(gradient**2) + sum(bias**2)); norm2/hypot | ||
| 241 | ! compute it without overflowing squares or raising negatives to a real power | ||
| 242 | − | if(present(bias))then | |
| 243 | − | total_norm = hypot(norm2(gradient), norm2(bias)) | |
| 244 | else | ||
| 245 | − | total_norm = norm2(gradient) | |
| 246 | end if | ||
| 247 | ! Rescale only when the limit is exceeded; this also avoids dividing | ||
| 248 | ! by zero for an all-zero gradient. | ||
| 249 | − | if(total_norm.gt.this%norm)then | |
| 250 | − | factor = this%norm/total_norm | |
| 251 | − | gradient = gradient * factor | |
| 252 | − | if(present(bias)) bias = bias * factor | |
| 253 | end if | ||
| 254 | end if | ||
| 255 | |||
| 256 | − | end subroutine apply_clip | |
| 257 | !############################################################################### | ||
| 258 | |||
| 259 | − | end module athena__clipper | |
| 260 |