| Line | Branch | Exec | Source |
|---|---|---|---|
submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_batchnorm
  !! Submodule implementing batch normalisation for extended diffstruc
  !! arrays: a training-mode forward pass, an inference-mode forward pass,
  !! and the partial derivatives used for backpropagation.
  !!
  !! Data layout (as indexed throughout this submodule): `val(:, s)` holds
  !! sample `s`. Channel `c` occupies rows `(c-1)*n+1 : c*n`, where
  !! `n = product(shape(1:size(shape)-1))` and the channel count `C` is
  !! `shape(size(shape))`. `params%val(1:C,1)` are the scales (gamma) and
  !! `params%val(C+1:2*C,1)` the shifts (beta).

contains

!###############################################################################
  module function batchnorm_inference( &
       input, params, mean, variance, epsilon &
  ) result( output )
    !! Batch normalisation using externally supplied (e.g. running) statistics.
    !!
    !! For every value x in channel c:
    !!   y = gamma(c) * (x - mean(c)) / sqrt(variance(c) + epsilon) + beta(c)
    !! No gradient bookkeeping (operands / partial procedures) is attached.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    !! Input array; channels along the last entry of `input%shape`
    class(array_type), intent(in), target :: params
    !! Per-channel scales (gamma) followed by per-channel shifts (beta)
    real(real32), dimension(:), intent(in) :: mean
    !! Per-channel mean used for normalisation
    real(real32), dimension(:), intent(in) :: variance
    !! Per-channel variance used for normalisation
    real(real32), intent(in) :: epsilon
    !! Constant added to the variance before the square root
    type(batchnorm_array_type), pointer :: output
    !! Newly allocated normalised array

    ! Local variables
    integer :: i, c, s, offset
    integer :: num_elements, num_dims, num_channels, num_samples

    ! The result is a pointer: it must be allocated before any of its
    ! components or type-bound procedures are used.
    allocate(output)
    if(output%allocated) call output%deallocate()
    num_samples = size(input%val,2)
    ! Same layout as the training-mode `batchnorm`: input shape + sample count.
    call output%allocate(array_shape = [ input%shape, num_samples ])
    output%epsilon = epsilon
    output%mean = mean
    output%variance = variance

    num_dims = size(input%shape)
    num_elements = product(input%shape(1:num_dims - 1))
    num_channels = input%shape(num_dims)

    do c = 1, num_channels
       offset = (c - 1) * num_elements
       do concurrent(s = 1:num_samples, i = 1:num_elements)
          output%val(offset + i, s) = &
               params%val(c,1) * ( input%val(offset + i, s) - mean(c) ) / &
               sqrt(variance(c) + output%epsilon) + &
               params%val(c + num_channels,1)
       end do
    end do

  end function batchnorm_inference
!-------------------------------------------------------------------------------
  module function batchnorm( &
       input, params, momentum, mean, variance, epsilon &
  ) result( output )
    !! Batch normalisation operation (training mode).
    !!
    !! Each channel is normalised with the mean and biased (population)
    !! variance of the current batch, then scaled by gamma and shifted by beta.
    !! The statistics stored on the result are
    !!   momentum * supplied + (1 - momentum) * batch
    !! or the batch statistics alone when momentum <= 1e-8.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: input
    !! Input array; channels along the last entry of `input%shape`
    class(array_type), intent(in), target :: params
    !! Per-channel scales (gamma) followed by per-channel shifts (beta)
    real(real32), intent(in) :: momentum
    !! Weight given to the supplied statistics when blending
    real(real32), dimension(:), intent(in) :: mean
    !! Per-channel mean to blend with the batch mean
    real(real32), dimension(:), intent(in) :: variance
    !! Per-channel variance to blend with the batch variance
    real(real32), intent(in) :: epsilon
    !! Constant added to the variance before the square root
    type(batchnorm_array_type), pointer :: output
    !! Newly allocated normalised array, linked to its operands for autodiff

    ! Local variables
    integer :: i, c, s, offset
    integer :: num_elements, num_dims, num_channels, num_samples
    real(real32) :: mu, var

    allocate(output)
    if(output%allocated) call output%deallocate()
    num_samples = size(input%val,2)
    call output%allocate(array_shape = [ input%shape, num_samples ])
    output%epsilon = epsilon
    output%mean = mean
    output%variance = variance

    num_dims = size(input%shape)
    num_elements = product(input%shape(1:num_dims - 1))
    num_channels = input%shape(num_dims)

    ! Plain loop over channels: each iteration reduces into the scalar
    ! temporaries mu/var, so iterations must not share them.
    do c = 1, num_channels
       offset = (c - 1) * num_elements
       call channel_batch_stats( &
            input%val(offset + 1:offset + num_elements,:), mu, var &
       )

       if(momentum .gt. 1.E-8_real32)then
          output%mean(c) = momentum * mean(c) + (1._real32 - momentum) * mu
          output%variance(c) = &
               momentum * variance(c) + (1._real32 - momentum) * var
       else
          output%mean(c) = mu
          output%variance(c) = var
       end if

       do concurrent(s = 1:num_samples, i = 1:num_elements)
          output%val(offset + i, s) = &
               params%val(c,1) * ( input%val(offset + i, s) - mu ) / &
               sqrt(var + output%epsilon) + params%val(c + num_channels,1)
       end do
    end do

    output%get_partial_left => get_partial_batchnorm_left
    output%get_partial_left_val => get_partial_batchnorm_left_val
    output%get_partial_right => get_partial_batchnorm_right
    output%get_partial_right_val => get_partial_batchnorm_right_val
    if(input%requires_grad .or. params%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = input%is_forward .or. params%is_forward
       output%operation = 'batchnorm'
       output%left_operand => input
       output%right_operand => params
    end if

  end function batchnorm
!-------------------------------------------------------------------------------
  pure subroutine channel_batch_stats(x, mu, var)
    !! Mean and biased (population) variance over every value of `x`.
    !!
    !! Shared by the forward and backward passes so that both see
    !! bit-identical batch statistics.
    implicit none

    ! Arguments
    real(real32), dimension(:,:), intent(in) :: x
    !! One channel's values across all samples
    real(real32), intent(out) :: mu
    !! Mean of `x`
    real(real32), intent(out) :: var
    !! Mean squared deviation of `x` from `mu`

    ! Local variables
    real(real32) :: norm

    norm = real(size(x), real32)
    mu = sum(x) / norm
    var = sum( (x - mu) ** 2 ) / norm

  end subroutine channel_batch_stats
!-------------------------------------------------------------------------------
  function get_partial_batchnorm_left(this, upstream_grad) result(output)
    !! Partial derivative of a batchnorm result with respect to its input.
    !!
    !! Allocates an array shaped like the input (one column per upstream
    !! sample) and fills it via `this%get_partial_left_val`.
    implicit none

    ! Arguments
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    ! Local variables
    class(array_type), pointer :: input

    input => this%left_operand
    call output%allocate(array_shape = [ input%shape, size(upstream_grad%val,2) ])

    call this%get_partial_left_val(upstream_grad%val, output%val)

  end function get_partial_batchnorm_left
!-------------------------------------------------------------------------------
  function get_partial_batchnorm_right(this, upstream_grad) result(output)
    !! Partial derivative of a batchnorm result with respect to its params.
    !!
    !! Allocates an array shaped like the params (a single column) and fills
    !! it via `this%get_partial_right_val`.
    implicit none

    ! Arguments
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    ! Local variables
    class(array_type), pointer :: params

    params => this%right_operand
    call output%allocate(array_shape = [ params%shape, 1 ])

    call this%get_partial_right_val(upstream_grad%val, output%val)

  end function get_partial_batchnorm_right
!-------------------------------------------------------------------------------
  pure subroutine get_partial_batchnorm_left_val(this, upstream_grad, output)
    !! Get partial derivative wrt input for batchnorm (subroutine version).
    !!
    !! With x_hat = (x - mu) / sqrt(var + eps) and dx_hat = gamma * dy, every
    !! element of channel c receives
    !!   dx = (N*dx_hat - sum(dx_hat) - x_hat*sum(dx_hat*x_hat))
    !!        / (N * sqrt(var + eps))
    !! where N is the number of values in the channel across the batch and
    !! mu/var are the batch statistics used by the forward pass.
    !! Does nothing unless `this` is a `batchnorm_array_type`.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    ! Local variables
    integer :: c, offset, num_dims, num_elements, num_channels
    real(real32) :: mu, var, std_dev, norm, sum_dx_hat, sum_dx_hat_x_hat
    real(real32), allocatable :: x_hat(:,:), dx_hat(:,:)

    select type(this)
    type is (batchnorm_array_type)
       ! Size everything from the input operand, exactly as the forward pass.
       num_dims = size(this%left_operand%shape)
       num_elements = product(this%left_operand%shape(1:num_dims - 1))
       num_channels = this%left_operand%shape(num_dims)
       norm = real(num_elements * size(upstream_grad,2), real32)

       output = 0._real32

       allocate(x_hat(num_elements, size(upstream_grad,2)))
       allocate(dx_hat(num_elements, size(upstream_grad,2)))

       do c = 1, num_channels
          offset = (c - 1) * num_elements

          ! Recompute the batch statistics the forward pass normalised with.
          ! this%mean/this%variance hold the momentum-blended values, which
          ! differ from them whenever momentum > 0.
          call channel_batch_stats( &
               this%left_operand%val(offset + 1:offset + num_elements,:), &
               mu, var &
          )
          std_dev = sqrt(var + this%epsilon)

          ! Normalised input
          x_hat(:,:) = ( &
               this%left_operand%val(offset + 1:offset + num_elements,:) - mu &
          ) / std_dev

          ! Gradient of normalised input
          dx_hat(:,:) = upstream_grad(offset + 1:offset + num_elements,:) * &
               this%right_operand%val(c,1)

          ! Channel-wide reductions do not depend on the element: compute once.
          sum_dx_hat = sum(dx_hat)
          sum_dx_hat_x_hat = sum(dx_hat * x_hat)

          ! Gradient wrt input
          output(offset + 1:offset + num_elements,:) = &
               (1._real32 / (norm * std_dev)) * &
               (norm * dx_hat - sum_dx_hat - x_hat * sum_dx_hat_x_hat)
       end do
    end select

  end subroutine get_partial_batchnorm_left_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_batchnorm_right_val(this, upstream_grad, output)
    !! Get partial derivative wrt params for batchnorm (subroutine version).
    !!
    !! For channel c (C channels in total):
    !!   output(c,1)   = sum(dy * x_hat)   (gradient wrt gamma)
    !!   output(c+C,1) = sum(dy)           (gradient wrt beta)
    !! where x_hat is normalised with the batch statistics used by the
    !! forward pass. Does nothing unless `this` is a `batchnorm_array_type`.
    implicit none

    ! Arguments
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    ! Local variables
    integer :: c, offset, num_dims, num_elements, num_channels
    real(real32) :: mu, var
    real(real32), allocatable :: x_hat(:,:)

    select type(this)
    type is (batchnorm_array_type)
       ! Size everything from the input operand, exactly as the forward pass.
       num_dims = size(this%left_operand%shape)
       num_elements = product(this%left_operand%shape(1:num_dims - 1))
       num_channels = this%left_operand%shape(num_dims)

       output = 0._real32

       allocate(x_hat(num_elements, size(upstream_grad,2)))

       do c = 1, num_channels
          offset = (c - 1) * num_elements

          ! Batch statistics, not the momentum-blended this%mean/variance.
          call channel_batch_stats( &
               this%left_operand%val(offset + 1:offset + num_elements,:), &
               mu, var &
          )

          ! Normalised input
          x_hat(:,:) = ( &
               this%left_operand%val(offset + 1:offset + num_elements,:) - mu &
          ) / sqrt(var + this%epsilon)

          output(c,1) = &
               sum(upstream_grad(offset + 1:offset + num_elements,:) * x_hat)
          output(c + num_channels,1) = &
               sum(upstream_grad(offset + 1:offset + num_elements,:))
       end do
    end select

  end subroutine get_partial_batchnorm_right_val
!###############################################################################

end submodule athena__diffstruc_extd_submodule_batchnorm
| 237 |