| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | use coreutils, only: stop_program | ||
| 4 | use diffstruc, only: & | ||
| 5 | operator(+), operator(-), operator(*), operator(/), operator(**), concat, exp, sum, merge | ||
| 6 | |||
| 7 | contains | ||
| 8 | |||
| 9 | !############################################################################### | ||
| 10 | − | module function add_array_ptr(a, idx1, idx2) result(c) | |
| 11 | !! Sum the autodiff arrays at element (idx1, idx2) across all entries of a (requires at least two) | ||
| 12 | implicit none | ||
| 13 | |||
| 14 | ! Arguments | ||
| 15 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 16 | integer, intent(in) :: idx1, idx2 | ||
| 17 | type(array_type), pointer :: c | ||
| 18 | |||
| 19 | ! Local variables | ||
| 20 | integer :: i | ||
| 21 | |||
| 22 | − | c => a(1)%array(idx1, idx2) + a(2)%array(idx1, idx2) | |
| 23 | − | do i = 3, size(a) | |
| 24 | − | c => c + a(i)%array(idx1, idx2) | |
| 25 | end do | ||
| 26 | − | end function add_array_ptr | |
| 27 | !############################################################################### | ||
| 28 | |||
| 29 | |||
| 30 | !############################################################################### | ||
| 31 | − | module function concat_array_ptr(a, idx1, idx2, dim) result(c) | |
| 32 | !! Concatenate the autodiff arrays at element (idx1, idx2) across all entries of a along the specified dimension (requires at least two) | ||
| 33 | implicit none | ||
| 34 | |||
| 35 | ! Arguments | ||
| 36 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 37 | integer, intent(in) :: idx1, idx2, dim | ||
| 38 | type(array_type), pointer :: c | ||
| 39 | |||
| 40 | ! Local variables | ||
| 41 | integer :: i | ||
| 42 | |||
| 43 | − | c => concat(a(1)%array(idx1, idx2), a(2)%array(idx1, idx2), dim) | |
| 44 | − | do i = 3, size(a) | |
| 45 | − | c => concat(c, a(i)%array(idx1, idx2), dim) | |
| 46 | end do | ||
| 47 | − | end function concat_array_ptr | |
| 48 | !############################################################################### | ||
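
Both helpers assume `size(a) >= 2`: the result is seeded with the first two operands and the remaining ones are folded in one at a time. Below is a minimal standalone sketch of that left-fold shape, using plain reals in place of the `array_ptr_type` graph nodes; the program and its names are illustrative only, not part of diffstruc. The same pattern applies to `concat_array_ptr`, with `concat` as the combining operation.

```fortran
program fold_demo
   !! Sketch of the left-fold used by add_array_ptr:
   !! seed with the first pair, then accumulate the rest.
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   real(real32) :: operands(4) = [1.0_real32, 2.0_real32, 3.0_real32, 4.0_real32]
   real(real32) :: c
   integer :: i

   ! Requires at least two operands, exactly as add_array_ptr does
   c = operands(1) + operands(2)
   do i = 3, size(operands)
      c = c + operands(i)
   end do
   print *, 'sum =', c   ! expect 10.0
end program fold_demo
```
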
| 49 | |||
| 50 | |||
| 51 | !##############################################################################! | ||
| 52 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 53 | !##############################################################################! | ||
| 54 | |||
| 55 | |||
| 56 | !############################################################################### | ||
| 57 | − | module function add_bias(input, bias, dim, dim_act_on_shape) result(output) | |
| 58 | !! Add bias to input array along specified dimension | ||
| 59 | implicit none | ||
| 60 | |||
| 61 | ! Arguments | ||
| 62 | class(array_type), intent(in), target :: input | ||
| 63 | class(array_type), intent(in), target :: bias | ||
| 64 | integer, intent(in) :: dim | ||
| 65 | logical, intent(in), optional :: dim_act_on_shape | ||
| 66 | type(array_type), pointer :: output | ||
| 67 | |||
| 68 | ! Local variables | ||
| 69 | integer :: i, j, k, s, idx, itmp1 | ||
| 70 | integer :: num_elements_pre, num_elements_post, num_dims | ||
| 71 | logical :: dim_act_on_shape_ | ||
| 72 | |||
| 73 | − | if(present(dim_act_on_shape))then | |
| 74 | − | dim_act_on_shape_ = dim_act_on_shape | |
| 75 | else | ||
| 76 | − | dim_act_on_shape_ = .false. | |
| 77 | end if | ||
| 78 | |||
| 79 | − | output => input%create_result() | |
| 80 | − | allocate(output%indices(2)) | |
| 81 | − | output%indices(1) = dim | |
| 82 | − | if(dim_act_on_shape_)then | |
| 83 | − | num_dims = size(input%shape) | |
| 84 | − | if(dim .gt. num_dims) then | |
| 85 | − | call stop_program("Dimension for add_bias exceeds input dimensions") | |
| 86 | − | return | |
| 87 | − | elseif(size(bias%shape) .ne. 1)then | |
| 88 | − | call stop_program("Bias must be a 1D array") | |
| 89 | − | return | |
| 90 | end if | ||
| 91 | − | num_elements_pre = 1 | |
| 92 | − | num_elements_post = 1 | |
| 93 | − | do i = 1, num_dims | |
| 94 | − | if(i .lt. dim)then | |
| 95 | − | num_elements_pre = num_elements_pre * input%shape(i) | |
| 96 | − | elseif(i .gt. dim)then | |
| 97 | − | num_elements_post = num_elements_post * input%shape(i) | |
| 98 | end if | ||
| 99 | end do | ||
| 100 | |||
| 101 | − | itmp1 = num_elements_pre * input%shape(dim) | |
| 102 | − | do s = 1, size(input%val, 2) | |
| 103 | − | do k = 1, num_elements_post | |
| 104 | − | do j = 1, bias%shape(1) | |
| 105 | − | idx = (j - 1) * num_elements_pre + (k - 1) * itmp1 | |
| 106 | − | do i = 1, num_elements_pre | |
| 107 | − | output%val(idx + i, s) = input%val(idx + i, s) + bias%val(j,1) | |
| 108 | end do | ||
| 109 | end do | ||
| 110 | end do | ||
| 111 | end do | ||
| 112 | − | output%indices(2) = 1 | |
| 113 | else | ||
| 114 | − | call stop_program("add_bias: dim_act_on_shape=.false. not implemented yet") | |
| 115 | − | output%indices(2) = 0 | |
| 116 | end if | ||
| 117 | |||
| 118 | − | output%get_partial_left => get_partial_add | |
| 119 | − | output%get_partial_right => get_partial_add_bias | |
| 120 | − | output%get_partial_left_val => get_partial_add_val | |
| 121 | − | output%get_partial_right_val => get_partial_add_bias_val | |
| 122 | − | if(input%requires_grad .or. bias%requires_grad)then | |
| 123 | − | output%requires_grad = .true. | |
| 124 | − | output%is_forward = input%is_forward .or. bias%is_forward | |
| 125 | − | output%operation = 'add_bias' | |
| 126 | − | output%left_operand => input | |
| 127 | − | output%right_operand => bias | |
| 128 | end if | ||
| 129 | |||
| 130 | − | end function add_bias | |
| 131 | !------------------------------------------------------------------------------- | ||
| 132 | − | function get_partial_add(this, upstream_grad) result(output) | |
| 133 | !! Get partial derivative with respect to left operand | ||
| 134 | implicit none | ||
| 135 | class(array_type), intent(inout) :: this | ||
| 136 | type(array_type), intent(in) :: upstream_grad | ||
| 137 | type(array_type) :: output | ||
| 138 | |||
| 139 | − | output = upstream_grad | |
| 140 | − | end function get_partial_add | |
| 141 | !------------------------------------------------------------------------------- | ||
| 142 | − | pure subroutine get_partial_add_val(this, upstream_grad, output) | |
| 143 | !! Get partial derivative with respect to left operand | ||
| 144 | implicit none | ||
| 145 | class(array_type), intent(in) :: this | ||
| 146 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 147 | real(real32), dimension(:,:), intent(out) :: output | ||
| 148 | |||
| 149 | − | if(size(upstream_grad,2).ne.size(output,2))then | |
| 150 | − | if(size(output,1).eq.1)then | |
| 151 | − | output(1,1) = sum(upstream_grad) | |
| 152 | else | ||
| 153 | − | output(:,1) = sum(upstream_grad, dim=2) | |
| 154 | end if | ||
| 155 | else | ||
| 156 | − | if(size(output,1).eq.1 .and. size(output,1).ne.size(upstream_grad,1))then | |
| 157 | − | output(1,:) = sum(upstream_grad, dim=1) | |
| 158 | else | ||
| 159 | − | output = upstream_grad | |
| 160 | end if | ||
| 161 | end if | ||
| 162 | − | end subroutine get_partial_add_val | |
| 163 | !------------------------------------------------------------------------------- | ||
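
`get_partial_add_val` implements the usual broadcasting rule for addition: wherever the operand was broadcast to match the upstream gradient, its gradient is the upstream gradient summed over the broadcast axes. A minimal sketch of that reduction on plain arrays (shapes and names chosen for illustration, not taken from diffstruc):

```fortran
program broadcast_grad_demo
   !! Gradient of a broadcast add: an operand shared across an axis
   !! accumulates the upstream gradient summed over that axis.
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   real(real32) :: upstream(3, 4)       ! (features, samples)
   real(real32) :: per_feature(3), scalar

   call random_number(upstream)
   per_feature = sum(upstream, dim=2)   ! operand broadcast over samples
   scalar = sum(upstream)               ! operand broadcast over everything
   print *, 'per-feature grad:', per_feature
   print *, 'scalar grad:     ', scalar
end program broadcast_grad_demo
```
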
| 164 | − | function get_partial_add_bias(this, upstream_grad) result(output) | |
| 165 | !! Get partial derivative with respect to bias operand | ||
| 166 | implicit none | ||
| 167 | class(array_type), intent(inout) :: this | ||
| 168 | type(array_type), intent(in) :: upstream_grad | ||
| 169 | type(array_type) :: output | ||
| 170 | |||
| 171 | − | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 172 | − | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 173 | |||
| 174 | − | end function get_partial_add_bias | |
| 175 | !------------------------------------------------------------------------------- | ||
| 176 | − | pure subroutine get_partial_add_bias_val(this, upstream_grad, output) | |
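| | !! Get partial derivative with respect to bias operand (in-place version) | ||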
| 177 | implicit none | ||
| 178 | |||
| 179 | ! Arguments | ||
| 180 | class(array_type), intent(in) :: this | ||
| 181 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 182 | real(real32), dimension(:,:), intent(out) :: output | ||
| 183 | |||
| 184 | integer :: i, j, k, s, idx, itmp1 | ||
| 185 | integer :: num_elements_pre, num_elements_post, num_dims | ||
| 186 | |||
| 187 | − | num_dims = size(this%left_operand%shape) | |
| 188 | − | num_elements_pre = 1 | |
| 189 | − | num_elements_post = 1 | |
| 190 | − | do i = 1, num_dims | |
| 191 | − | if(i .lt. this%indices(1))then | |
| 192 | − | num_elements_pre = num_elements_pre * this%left_operand%shape(i) | |
| 193 | − | elseif(i .gt. this%indices(1))then | |
| 194 | − | num_elements_post = num_elements_post * this%left_operand%shape(i) | |
| 195 | end if | ||
| 196 | end do | ||
| 197 | |||
| 198 | − | itmp1 = num_elements_pre * this%left_operand%shape(this%indices(1)) | |
| 199 | − | output = 0._real32 | |
| 200 | − | do s = 1, size(upstream_grad, 2) | |
| 201 | − | do k = 1, num_elements_post | |
| 202 | − | do j = 1, this%right_operand%shape(1) | |
| 203 | − | idx = (j - 1) * num_elements_pre + (k - 1) * itmp1 | |
| 204 | − | do i = 1, num_elements_pre | |
| 205 | − | output(j,1) = output(j,1) + upstream_grad(idx + i, s) | |
| 206 | end do | ||
| 207 | end do | ||
| 208 | end do | ||
| 209 | end do | ||
| 210 | |||
| 211 | − | end subroutine get_partial_add_bias_val | |
| 212 | !############################################################################### | ||
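
The forward loop and `get_partial_add_bias_val` both rely on the same column-major index decomposition: an element with pre-index `i`, bias-index `j`, and post-index `k` sits at `(j-1)*num_elements_pre + (k-1)*num_elements_pre*shape(dim) + i` in the flattened array. The standalone program below (independent of diffstruc; the shape `(2,3,2)` is made up for illustration) checks that walk against whole-array sections. Note that the loops also assume `bias%shape(1)` equals `input%shape(dim)`, which the guard clauses above do not currently enforce.

```fortran
program bias_index_demo
   !! Check the flattened-index arithmetic used by add_bias:
   !! for shape (d1,d2,d3) and dim=2, element (i,j,k) lives at
   !! i + (j-1)*d1 + (k-1)*d1*d2 in the column-major flattening.
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   integer, parameter :: d1 = 2, d2 = 3, d3 = 2
   integer, parameter :: pre = d1, post = d3   ! elements before/after dim=2
   real(real32) :: x(d1, d2, d3), flat(d1*d2*d3), bias(d2)
   integer :: i, j, k, idx

   call random_number(x)
   call random_number(bias)
   flat = reshape(x, [d1*d2*d3])

   ! Walk the flattened array exactly as add_bias does
   do k = 1, post
      do j = 1, d2
         idx = (j - 1) * pre + (k - 1) * pre * d2
         do i = 1, pre
            flat(idx + i) = flat(idx + i) + bias(j)
         end do
      end do
   end do

   ! Reference: broadcast the bias with whole-array sections
   do j = 1, d2
      x(:, j, :) = x(:, j, :) + bias(j)
   end do

   print *, 'max index-arithmetic error:', &
        maxval(abs(flat - reshape(x, [d1*d2*d3])))   ! expect 0
end program bias_index_demo
```

The bias gradient mirrors the same walk in reverse: `get_partial_add_bias_val` accumulates the upstream gradient over every axis except `dim`, one sum per bias element.
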
| 213 | |||
| 214 | |||
| 215 | !############################################################################### | ||
| 216 | − | module function piecewise_array(input, gradient, limit) result(output) | |
| 217 | !! Apply a piecewise-linear activation to the input array: identity inside [-limit, limit], slope gradient outside | ||
| 218 | implicit none | ||
| 219 | |||
| 220 | ! Arguments | ||
| 221 | class(array_type), intent(in), target :: input | ||
| 222 | real(real32), intent(in) :: gradient | ||
| 223 | real(real32), intent(in) :: limit | ||
| 224 | type(array_type), pointer :: output | ||
| 225 | type(array_type), pointer :: b_array | ||
| 226 | |||
| 227 | − | output => input%create_result() | |
| 228 | − | where(input%val.ge.limit) | |
| 229 | − | output%val = gradient * (input%val - limit) + limit | |
| 230 | − | elsewhere(input%val.le.-limit) | |
| 231 | − | output%val = gradient * (input%val + limit) - limit | |
| 232 | elsewhere | ||
| 233 | − | output%val = input%val | |
| 234 | end where | ||
| 235 | |||
| 236 | − | output%get_partial_left => get_partial_piecewise | |
| 237 | − | output%get_partial_left_val => get_partial_piecewise_val | |
| 238 | − | if(input%requires_grad)then | |
| 239 | − | output%requires_grad = .true. | |
| 240 | − | output%is_forward = input%is_forward | |
| 241 | − | output%operation = 'piecewise' | |
| 242 | − | output%left_operand => input | |
| 243 | − | output%owns_left_operand = input%is_temporary | |
| 244 | end if | ||
| 245 | − | allocate(b_array) | |
| 246 | − | b_array%is_sample_dependent = .false. | |
| 247 | − | b_array%requires_grad = .false. | |
| 248 | − | call b_array%allocate(array_shape=[2, 1]) | |
| 249 | − | b_array%val(1,1) = gradient | |
| 250 | − | b_array%val(2,1) = limit | |
| 251 | − | output%right_operand => b_array | |
| 252 | − | output%owns_right_operand = .true. | |
| 253 | |||
| 254 | − | end function piecewise_array | |
| 255 | !------------------------------------------------------------------------------- | ||
| 256 | − | function get_partial_piecewise(this, upstream_grad) result(output) | |
| 257 | !! Get partial derivative of piecewise activation | ||
| 258 | implicit none | ||
| 259 | class(array_type), intent(inout) :: this | ||
| 260 | type(array_type), intent(in) :: upstream_grad | ||
| 261 | type(array_type) :: output | ||
| 262 | |||
| 263 | type(array_type), pointer :: ptr | ||
| 264 | |||
| 265 | ptr => merge( & | ||
| 266 | upstream_grad * this%right_operand%val(1,1), & | ||
| 267 | − | upstream_grad, & | |
| 268 | − | this%left_operand%val.le.-this%right_operand%val(2,1) .or. & | |
| 269 | − | this%left_operand%val.ge.this%right_operand%val(2,1) & | |
| 270 | − | ) | |
| 271 | − | call output%assign_and_deallocate_source(ptr) | |
| 272 | |||
| 273 | − | end function get_partial_piecewise | |
| 274 | !------------------------------------------------------------------------------- | ||
| 275 | − | pure subroutine get_partial_piecewise_val(this, upstream_grad, output) | |
| 276 | !! Get partial derivative of piecewise activation (in-place version) | ||
| 277 | implicit none | ||
| 278 | class(array_type), intent(in) :: this | ||
| 279 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 280 | real(real32), dimension(:,:), intent(out) :: output | ||
| 281 | |||
| 282 | − | where(this%left_operand%val.gt.-this%right_operand%val(2,1) .and. & | |
| 283 | − | this%left_operand%val.lt.this%right_operand%val(2,1) & | |
| 284 | − | ) | |
| 285 | − | output = upstream_grad | |
| 286 | elsewhere | ||
| 287 | − | output = upstream_grad * this%right_operand%val(1,1) | |
| 288 | end where | ||
| 289 | |||
| 290 | − | end subroutine get_partial_piecewise_val | |
| 291 | !############################################################################### | ||
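
The piecewise function has slope `gradient` outside `[-limit, limit]` and slope 1 inside, so the partials pass the upstream gradient through unchanged in the interior region and scale it by `gradient` in the two saturated regions. A standalone finite-difference check of those slopes (parameter values are illustrative):

```fortran
program piecewise_check
   !! Finite-difference check of the piecewise activation:
   !! slope is `gradient` outside [-limit, limit] and 1 inside.
   use, intrinsic :: iso_fortran_env, only: real64
   implicit none
   real(real64), parameter :: gradient = 0.1_real64, limit = 1.0_real64
   real(real64), parameter :: h = 1.0e-6_real64
   real(real64) :: x, fd
   integer :: i

   do i = -2, 2
      x = real(i, real64) * 0.8_real64   ! samples inside and outside +-limit
      fd = (f(x + h) - f(x - h)) / (2.0_real64 * h)
      print '(a,f6.2,2(a,f8.4))', 'x =', x, '  fd slope =', fd, &
           '  expected =', merge(gradient, 1.0_real64, abs(x) >= limit)
   end do

contains

   pure function f(x) result(y)
      !! Same three-region definition as piecewise_array
      real(real64), intent(in) :: x
      real(real64) :: y
      if(x >= limit)then
         y = gradient * (x - limit) + limit
      elseif(x <= -limit)then
         y = gradient * (x + limit) - limit
      else
         y = x
      end if
   end function f
end program piecewise_check
```
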
| 292 | |||
| 293 | |||
| 294 | !############################################################################### | ||
| 295 | − | module function softmax_array(input, dim) result(output) | |
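| | !! Apply the softmax function to the input array; dim selects the index that enumerates the distributions (dim=1 normalises each row of val, dim=2 each column) | ||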
| 296 | implicit none | ||
| 297 | class(array_type), intent(in), target :: input | ||
| 298 | integer, intent(in) :: dim | ||
| 299 | type(array_type), pointer :: output | ||
| 300 | |||
| 301 | integer :: i | ||
| 302 | |||
| 303 | − | output => input%create_result() | |
| 304 | − | if(dim.eq.1)then | |
| 305 | − | do i = 1, size(input%val, 1) | |
| 306 | − | output%val(i, :) = exp(input%val(i, :) - maxval(input%val(i,:))) | |
| 307 | − | output%val(i, :) = output%val(i, :) / sum(output%val(i, :)) | |
| 308 | end do | ||
| 309 | − | elseif(dim.eq.2)then | |
| 310 | − | do i = 1, size(input%val, 2) | |
| 311 | − | output%val(:, i) = exp(input%val(:, i) - maxval(input%val(:, i))) | |
| 312 | − | output%val(:, i) = output%val(:, i) / sum(output%val(:, i)) | |
| 313 | end do | ||
| 314 | else | ||
| 315 | − | call stop_program("softmax_array: Unsupported dimension") | |
| 316 | end if | ||
| 317 | − | allocate(output%indices(1)) | |
| 318 | − | output%indices(1) = dim | |
| 319 | |||
| 320 | − | output%get_partial_left => get_partial_softmax | |
| 321 | − | output%get_partial_left_val => get_partial_softmax_val | |
| 322 | − | if(input%requires_grad)then | |
| 323 | − | output%requires_grad = .true. | |
| 324 | − | output%is_forward = input%is_forward | |
| 325 | − | output%operation = 'softmax' | |
| 326 | − | output%left_operand => input | |
| 327 | − | output%owns_left_operand = input%is_temporary | |
| 328 | end if | ||
| 329 | |||
| 330 | − | end function softmax_array | |
| 331 | !------------------------------------------------------------------------------- | ||
| 332 | − | function get_partial_softmax(this, upstream_grad) result(output) | |
| 333 | !! Get partial derivative of softmax activation | ||
| 334 | implicit none | ||
| 335 | class(array_type), intent(inout) :: this | ||
| 336 | type(array_type), intent(in) :: upstream_grad | ||
| 337 | type(array_type) :: output | ||
| 338 | type(array_type), pointer :: ptr | ||
| 339 | |||
| 340 | integer :: dim | ||
| 341 | |||
| 342 | − | if(this%indices(1).eq.1)then | |
| 343 | − | dim = 2 | |
| 344 | else | ||
| 345 | − | dim = 1 | |
| 346 | end if | ||
| 347 | ! ptr => this * upstream_grad | ||
| 348 | ! ptr => ptr - this * sum(ptr, dim=dim) | ||
| 349 | − | ptr => softmax_reverse_array(this, upstream_grad, this%indices(1)) | |
| 350 | − | call output%assign_and_deallocate_source(ptr) | |
| 351 | |||
| 352 | − | end function get_partial_softmax | |
| 353 | !------------------------------------------------------------------------------- | ||
| 354 | − | pure subroutine get_partial_softmax_val(this, upstream_grad, output) | |
| 355 | !! Get partial derivative of softmax activation (in-place version) | ||
| 356 | implicit none | ||
| 357 | class(array_type), intent(in) :: this | ||
| 358 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 359 | real(real32), dimension(:,:), intent(out) :: output | ||
| 360 | |||
| 361 | integer :: s, dim | ||
| 362 | |||
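| | ! indices(1) enumerates the distributions, so the reduction runs over the other dimension | ||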
| 363 | − | if(this%indices(1).eq.1)then | |
| 364 | − | dim = 2 | |
| 365 | else | ||
| 366 | − | dim = 1 | |
| 367 | end if | ||
| 368 | − | output = this%val * upstream_grad | |
| 369 | − | if(dim.eq.1)then | |
| 370 | − | do s = 1, size(this%val, 2) | |
| 371 | − | output(:, s) = output(:, s) - this%val(:, s) * sum(output(:, s)) | |
| 372 | end do | ||
| 373 | − | elseif(dim.eq.2)then | |
| 374 | − | do s = 1, size(this%val, 1) | |
| 375 | − | output(s, :) = output(s, :) - this%val(s, :) * sum(output(s, :)) | |
| 376 | end do | ||
| 377 | end if | ||
| 378 | − | end subroutine get_partial_softmax_val | |
| 379 | !############################################################################### | ||
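
The forward pass subtracts the per-distribution maximum before exponentiating (the standard overflow guard, which leaves the result unchanged since softmax is shift-invariant), and the backward pass applies the softmax vector-Jacobian product `y*u - y*sum(y*u)`. A standalone check of both on a single vector (sizes and names are illustrative, independent of diffstruc):

```fortran
program softmax_check
   !! Numerically stable softmax and a finite-difference check of its
   !! vector-Jacobian product, grad = y*u - y*sum(y*u).
   use, intrinsic :: iso_fortran_env, only: real64
   implicit none
   integer, parameter :: n = 5
   real(real64), parameter :: h = 1.0e-6_real64
   real(real64) :: x(n), u(n), y(n), grad(n), fd(n), xp(n)
   integer :: k

   call random_number(x)
   x = 10.0_real64 * x            ! moderately large logits
   call random_number(u)

   y = softmax(x)
   grad = y * u - y * sum(y * u)  ! same form as get_partial_softmax_val

   do k = 1, n
      xp = x
      xp(k) = x(k) + h
      fd(k) = dot_product(u, softmax(xp))
      xp(k) = x(k) - h
      fd(k) = (fd(k) - dot_product(u, softmax(xp))) / (2.0_real64 * h)
   end do
   print *, 'max |analytic - fd| =', maxval(abs(grad - fd))

contains

   pure function softmax(x) result(y)
      !! Subtract the max before exponentiating, as softmax_array does
      real(real64), intent(in) :: x(:)
      real(real64) :: y(size(x))
      y = exp(x - maxval(x))
      y = y / sum(y)
   end function softmax
end program softmax_check
```
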
| 380 | |||
| 381 | |||
| 382 | !############################################################################### | ||
| 383 | − | module function swish_array(input, beta) result(output) | |
| 384 | !! Swish activation function: x * sigmoid(beta * x) | ||
| 385 | implicit none | ||
| 386 | |||
| 387 | ! Arguments | ||
| 388 | class(array_type), intent(in), target :: input | ||
| 389 | real(real32), intent(in) :: beta | ||
| 390 | type(array_type), pointer :: output | ||
| 391 | type(array_type), pointer :: b_array | ||
| 392 | |||
| 393 | − | output => input%create_result() | |
| 394 | − | output%val = input%val * (1._real32 / (1._real32 + exp(-beta * input%val))) | |
| 395 | |||
| 396 | − | output%get_partial_left => get_partial_swish | |
| 397 | − | output%get_partial_left_val => get_partial_swish_val | |
| 398 | − | if(input%requires_grad)then | |
| 399 | − | output%requires_grad = .true. | |
| 400 | − | output%is_forward = input%is_forward | |
| 401 | − | output%operation = 'swish' | |
| 402 | − | output%left_operand => input | |
| 403 | − | output%owns_left_operand = input%is_temporary | |
| 404 | end if | ||
| 405 | − | allocate(b_array) | |
| 406 | − | b_array%is_sample_dependent = .false. | |
| 407 | − | b_array%is_scalar = .true. | |
| 408 | − | b_array%requires_grad = .false. | |
| 409 | − | call b_array%allocate(array_shape=[1, 1]) | |
| 410 | − | b_array%val(1,1) = beta | |
| 411 | − | output%right_operand => b_array | |
| 412 | − | output%owns_right_operand = .true. | |
| 413 | |||
| 414 | − | end function swish_array | |
| 415 | !------------------------------------------------------------------------------- | ||
| 416 | − | function get_partial_swish(this, upstream_grad) result(output) | |
| 417 | !! Get partial derivative of swish activation | ||
| 418 | implicit none | ||
| 419 | class(array_type), intent(inout) :: this | ||
| 420 | type(array_type), intent(in) :: upstream_grad | ||
| 421 | type(array_type) :: output | ||
| 422 | |||
| 423 | type(array_type), pointer :: ptr | ||
| 424 | type(array_type), pointer :: exp_term | ||
| 425 | |||
| 426 | − | exp_term => exp(this%right_operand%val(1,1) * this%left_operand) | |
| 427 | |||
| 428 | ptr => upstream_grad * exp_term * ( & | ||
| 429 | − | this%right_operand%val(1,1) * this%left_operand + & | |
| 430 | exp_term + 1._real32 & | ||
| 431 | − | ) / ( ( exp_term + 1._real32 )**2._real32 ) | |
| 432 | |||
| 433 | − | call output%assign_and_deallocate_source(ptr) | |
| 434 | − | end function get_partial_swish | |
| 435 | !------------------------------------------------------------------------------- | ||
| 436 | − | pure subroutine get_partial_swish_val(this, upstream_grad, output) | |
| 437 | !! Get partial derivative of swish activation (in-place version) | ||
| 438 | implicit none | ||
| 439 | class(array_type), intent(in) :: this | ||
| 440 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 441 | real(real32), dimension(:,:), intent(out) :: output | ||
| 442 | |||
| 443 | − | real(real32), dimension(size(this%val,1), size(this%val,2)) :: exp_term | |
| 444 | |||
| 445 | − | exp_term = exp(this%right_operand%val(1,1) * this%left_operand%val) | |
| 446 | − | output = upstream_grad * exp_term * ( & | |
| 447 | − | this%right_operand%val(1,1) * this%left_operand%val + & | |
| 448 | − | exp_term + 1._real32 & | |
| 449 | − | ) / ( ( exp_term + 1._real32 )**2 ) | |
| 450 | |||
| 451 | − | end subroutine get_partial_swish_val | |
| 452 | !############################################################################### | ||
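
`get_partial_swish_val` evaluates the swish derivative in the closed form `exp(b*x) * (b*x + exp(b*x) + 1) / (exp(b*x) + 1)**2`, which is algebraically identical to the textbook form `sigma(b*x) + b*x*sigma(b*x)*(1 - sigma(b*x))`. A short standalone check of that identity (sample points chosen for illustration):

```fortran
program swish_check
   !! Compare the closed-form swish derivative used above with the
   !! sigmoid-based textbook form; the two should agree to round-off.
   use, intrinsic :: iso_fortran_env, only: real64
   implicit none
   real(real64), parameter :: beta = 1.5_real64
   real(real64) :: x, e, sg, lhs, rhs
   integer :: i

   do i = -3, 3
      x = real(i, real64)
      e = exp(beta * x)
      sg = 1.0_real64 / (1.0_real64 + exp(-beta * x))
      lhs = e * (beta * x + e + 1.0_real64) / (e + 1.0_real64)**2
      rhs = sg + beta * x * sg * (1.0_real64 - sg)
      print '(a,f5.1,a,es10.2)', 'x =', x, '  |lhs - rhs| =', abs(lhs - rhs)
   end do
end program swish_check
```
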
| 453 | |||
| 454 | |||
| 455 | !############################################################################### | ||
| 456 | − | function softmax_reverse_array(softmax, gradient, dim) result(output) | |
| 457 | !! Softmax vector-Jacobian product for reverse-mode autodiff: output = gradient * softmax - softmax * sum(gradient * softmax) | ||
| 458 | implicit none | ||
| 459 | class(array_type), intent(in), target :: softmax | ||
| 460 | class(array_type), intent(in), target :: gradient | ||
| 461 | integer, intent(in) :: dim | ||
| 462 | type(array_type), pointer :: output | ||
| 463 | |||
| 464 | integer :: i | ||
| 465 | − | real(real32), dimension(size(softmax%val,1), size(softmax%val,2)) :: temp_val | |
| 466 | |||
| 467 | |||
| 468 | − | output => softmax%create_result() | |
| 469 | − | temp_val = gradient%val * softmax%val | |
| 470 | − | if(dim.eq.1)then | |
| 471 | − | do concurrent(i=1:size(softmax%val,1)) | |
| 472 | − | temp_val(i, :) = temp_val(i, :) - softmax%val(i, :) * sum(temp_val(i, :)) | |
| 473 | end do | ||
| 474 | − | elseif(dim.eq.2)then | |
| 475 | − | do concurrent(i=1:size(softmax%val,2)) | |
| 476 | − | temp_val(:, i) = temp_val(:, i) - softmax%val(:, i) * sum(temp_val(:, i)) | |
| 477 | end do | ||
| 478 | else | ||
| 479 | − | call stop_program("softmax_reverse_array: Unsupported dimension") | |
| 480 | end if | ||
| 481 | − | output%val = temp_val | |
| 482 | − | output%indices = [dim] | |
| 483 | |||
| 484 | − | output%get_partial_left => get_partial_softmax_reverse_left | |
| 485 | − | output%get_partial_left_val => get_partial_softmax_reverse_left_val | |
| 486 | − | output%get_partial_right => get_partial_softmax_reverse_right | |
| 487 | − | output%get_partial_right_val => get_partial_softmax_reverse_right_val | |
| 488 | − | if(softmax%requires_grad .or. gradient%requires_grad)then | |
| 489 | − | output%requires_grad = .true. | |
| 490 | − | output%is_forward = softmax%is_forward .or. gradient%is_forward | |
| 491 | − | output%operation = 'softmax_reverse' | |
| 492 | − | output%left_operand => softmax | |
| 493 | − | output%right_operand => gradient | |
| 494 | − | output%owns_left_operand = softmax%is_temporary | |
| 495 | − | output%owns_right_operand = gradient%is_temporary | |
| 496 | end if | ||
| 497 | |||
| 498 | − | end function softmax_reverse_array | |
| 499 | !------------------------------------------------------------------------------- | ||
| 500 | − | function get_partial_softmax_reverse_left(this, upstream_grad) result(output) | |
| 501 | !! Get partial derivative of the softmax reverse operation with respect to the softmax (left) operand | ||
| 502 | implicit none | ||
| 503 | class(array_type), intent(inout) :: this | ||
| 504 | type(array_type), intent(in) :: upstream_grad | ||
| 505 | type(array_type) :: output | ||
| 506 | |||
| 507 | type(array_type), pointer :: sum_yg, sum_yu | ||
| 508 | type(array_type), pointer :: ptr | ||
| 509 | |||
| 510 | − | sum_yg => sum(this%left_operand * this%right_operand, dim=3-this%indices(1)) | |
| 511 | − | sum_yu => sum(this%left_operand * upstream_grad, dim=3-this%indices(1)) | |
| 512 | |||
| 513 | − | ptr => upstream_grad * (this%right_operand - sum_yg) - this%right_operand * sum_yu | |
| 514 | − | call output%assign_and_deallocate_source(ptr) | |
| 515 | |||
| 516 | − | end function get_partial_softmax_reverse_left | |
| 517 | !------------------------------------------------------------------------------- | ||
| 518 | − | function get_partial_softmax_reverse_right(this, upstream_grad) result(output) | |
| 519 | !! Get partial derivative of the softmax reverse operation with respect to the gradient (right) operand | ||
| 520 | implicit none | ||
| 521 | class(array_type), intent(inout) :: this | ||
| 522 | type(array_type), intent(in) :: upstream_grad | ||
| 523 | type(array_type) :: output | ||
| 524 | |||
| 525 | type(array_type), pointer :: ptr | ||
| 526 | |||
| 527 | ptr => ( & | ||
| 528 | upstream_grad - & | ||
| 529 | − | sum(this%left_operand * upstream_grad, dim=3-this%indices(1)) & | |
| 530 | − | ) * this%left_operand | |
| 531 | − | call output%assign_and_deallocate_source(ptr) | |
| 532 | |||
| 533 | − | end function get_partial_softmax_reverse_right | |
| 534 | !------------------------------------------------------------------------------- | ||
| 535 | − | pure subroutine get_partial_softmax_reverse_left_val(this, upstream_grad, output) | |
| 536 | !! Get partial derivative of the softmax reverse operation with respect to the softmax (left) operand (in-place version) | ||
| 537 | implicit none | ||
| 538 | class(array_type), intent(in) :: this | ||
| 539 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 540 | real(real32), dimension(:,:), intent(out) :: output | ||
| 541 | |||
| 542 | integer :: dim, i | ||
| 543 | − | real(real32), dimension(size(this%val,this%indices(1))) :: sum_yg | |
| 544 | − | real(real32), dimension(size(this%val,this%indices(1))) :: sum_yu | |
| 545 | |||
| 546 | − | dim = this%indices(1) | |
| 547 | − | sum_yg = sum(this%left_operand%val * this%right_operand%val, dim=3-dim) | |
| 548 | − | sum_yu = sum(this%left_operand%val * upstream_grad, dim=3-dim) | |
| 549 | |||
| 550 | − | if(dim.eq.1)then | |
| 551 | − | do concurrent(i=1:size(this%val,1)) | |
| 552 | − | output(i, :) = & | |
| 553 | − | upstream_grad(i, :) * (this%right_operand%val(i, :) - sum_yg(i)) - & | |
| 554 | − | this%right_operand%val(i, :) * sum_yu(i) | |
| 555 | end do | ||
| 556 | else | ||
| 557 | − | do concurrent(i=1:size(this%val,2)) | |
| 558 | − | output(:, i) = & | |
| 559 | − | upstream_grad(:, i) * (this%right_operand%val(:, i) - sum_yg(i)) - & | |
| 560 | − | this%right_operand%val(:, i) * sum_yu(i) | |
| 561 | end do | ||
| 562 | end if | ||
| 563 | |||
| 564 | − | end subroutine get_partial_softmax_reverse_left_val | |
| 565 | !------------------------------------------------------------------------------- | ||
| 566 | − | pure subroutine get_partial_softmax_reverse_right_val(this, upstream_grad, output) | |
| 567 | !! Get partial derivative of the softmax reverse operation with respect to the gradient (right) operand (in-place version) | ||
| 568 | implicit none | ||
| 569 | class(array_type), intent(in) :: this | ||
| 570 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 571 | real(real32), dimension(:,:), intent(out) :: output | ||
| 572 | |||
| 573 | integer :: dim, i | ||
| 574 | − | real(real32), dimension(size(this%val,this%indices(1))) :: sum_yu | |
| 575 | |||
| 576 | − | dim = this%indices(1) | |
| 577 | − | if(dim.eq.1)then | |
| 578 | − | sum_yu = sum(this%left_operand%val * upstream_grad, dim=2) | |
| 579 | − | do concurrent(i=1:size(this%val,1)) | |
| 580 | − | output(i, :) = (upstream_grad(i, :) - sum_yu(i)) * this%left_operand%val(i, :) | |
| 581 | end do | ||
| 582 | else | ||
| 583 | − | sum_yu = sum(this%left_operand%val * upstream_grad, dim=1) | |
| 584 | − | do concurrent(i=1:size(this%val,2)) | |
| 585 | − | output(:, i) = (upstream_grad(:, i) - sum_yu(i)) * this%left_operand%val(:, i) | |
| 586 | end do | ||
| 587 | end if | ||
| 588 | − | end subroutine get_partial_softmax_reverse_right_val | |
| 589 | !############################################################################### | ||
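
The two partials of the reverse operation `z(y,g) = g*y - y*sum(g*y)` follow from differentiating `dot(u, z)` directly: `d/dy = u*(g - sum(g*y)) - g*sum(y*u)` and `d/dg = y*(u - sum(y*u))`, with each sum running over the softmax axis. A standalone finite-difference check on a single distribution (sizes are illustrative, independent of diffstruc):

```fortran
program softmax_reverse_check
   !! Finite-difference check of the softmax_reverse partials for
   !! z(y,g) = g*y - y*sum(g*y), with scalar objective dot(u, z).
   use, intrinsic :: iso_fortran_env, only: real64
   implicit none
   integer, parameter :: n = 5
   real(real64), parameter :: h = 1.0e-6_real64
   real(real64) :: y(n), g(n), u(n), gy(n), gg(n), fd(n), p(n)
   integer :: k

   call random_number(y); call random_number(g); call random_number(u)

   gy = u * (g - sum(g * y)) - g * sum(y * u)   ! left partial (w.r.t. y)
   gg = y * (u - sum(y * u))                    ! right partial (w.r.t. g)

   do k = 1, n
      p = y; p(k) = y(k) + h
      fd(k) = dot_product(u, z(p, g))
      p(k) = y(k) - h
      fd(k) = (fd(k) - dot_product(u, z(p, g))) / (2.0_real64 * h)
   end do
   print *, 'left  partial max error:', maxval(abs(gy - fd))

   do k = 1, n
      p = g; p(k) = g(k) + h
      fd(k) = dot_product(u, z(y, p))
      p(k) = g(k) - h
      fd(k) = (fd(k) - dot_product(u, z(y, p))) / (2.0_real64 * h)
   end do
   print *, 'right partial max error:', maxval(abs(gg - fd))

contains

   pure function z(y, g) result(res)
      !! Same Jacobian-vector product as softmax_reverse_array
      real(real64), intent(in) :: y(:), g(:)
      real(real64) :: res(size(y))
      res = g * y - y * sum(g * y)
   end function z
end program softmax_reverse_check
```
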
| 590 | |||
| 591 | end submodule athena__diffstruc_extd_submodule | ||
| 592 | |||