| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_merge | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | |||
| 4 | contains | ||
| 5 | |||
| 6 | !############################################################################### | ||
| 7 | − | module function merge_scalar_over_channels(tsource, fsource, mask) result(output) | |
| 8 | !! Masked merge: tsource value where mask(i,1) is true, scalar fsource elsewhere; mask reused for every index m | ||
| 9 | implicit none | ||
| 10 | ! NOTE(review): the true branch reads tsource%val(i,s) for every m, not row i + (m-1)*num_elements - confirm intended | ||
| 11 | ! Arguments: tsource = values kept where the mask is true, fsource = fill value, mask = flags (only column 1 is read) | ||
| 12 | class(array_type), intent(in), target :: tsource | ||
| 13 | real(real32), intent(in) :: fsource | ||
| 14 | logical, dimension(:,:), intent(in) :: mask | ||
| 15 | type(array_type), pointer :: output | ||
| 16 | |||
| 17 | ! Local variables: i = element index, m = index along the last dimension, s = column index of val | ||
| 18 | integer :: i, m, s | ||
| 19 | integer :: num_elements, num_dims | ||
| 20 | |||
| 21 | ! num_elements = product of all but the last entry of tsource%shape, used as the row stride between m-slices | ||
| 22 | − | output => tsource%create_result() | |
| 23 | − | num_dims = size(tsource%shape) | |
| 24 | − | num_elements = product(tsource%shape(1:num_dims - 1)) | |
| 25 | − | do concurrent(s = 1:size(tsource%val,2), m = 1: tsource%shape(num_dims)) | |
| 26 | − | do concurrent(i=1:num_elements) | |
| 27 | − | if(mask(i,1)) then | |
| 28 | − | output%val(i + (m-1) * num_elements,s) = tsource%val(i,s) | |
| 29 | else | ||
| 30 | − | output%val(i + (m-1) * num_elements,s) = fsource | |
| 31 | end if | ||
| 32 | end do | ||
| 33 | end do | ||
| 34 | − | output%mask = mask | |
| 35 | ! The partial-derivative procedure pointers are always set; the fields below only when tsource%requires_grad | ||
| 36 | − | output%get_partial_left => get_partial_merge_scalar_over_channels | |
| 37 | − | output%get_partial_left_val => get_partial_merge_scalar_over_channels_val | |
| 38 | − | if(tsource%requires_grad) then | |
| 39 | − | output%requires_grad = .true. | |
| 40 | − | output%is_forward = tsource%is_forward | |
| 41 | − | output%operation = 'merge_over_channels' | |
| 42 | − | output%left_operand => tsource | |
| 43 | end if | ||
| 44 | |||
| 45 | − | end function merge_scalar_over_channels | |
| 46 | !------------------------------------------------------------------------------- | ||
| 47 | − | function get_partial_merge_scalar_over_channels(this, upstream_grad) result(output) | |
| 48 | implicit none | ||
| 49 | class(array_type), intent(inout) :: this | ||
| 50 | type(array_type), intent(in) :: upstream_grad | ||
| 51 | type(array_type) :: output | ||
| 52 | ! Re-applies merge_scalar_over_channels to upstream_grad with fill value zero and the mask stored in this | ||
| 53 | − | output = merge_scalar_over_channels(upstream_grad, 0._real32, this%mask) | |
| 54 | |||
| 55 | − | end function get_partial_merge_scalar_over_channels | |
| 56 | !------------------------------------------------------------------------------- | ||
| 57 | − | pure subroutine get_partial_merge_scalar_over_channels_val( & | |
| 58 | − | this, upstream_grad, output & | |
| 59 | ) | ||
| 60 | implicit none | ||
| 61 | class(array_type), intent(in) :: this | ||
| 62 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 63 | real(real32), dimension(:,:), intent(out) :: output | ||
| 64 | ! Fills output from upstream_grad under this%mask (column 1 only), writing zero where the mask is false | ||
| 65 | integer :: i, m, s | ||
| 66 | integer :: num_elements, num_dims, num_channels | ||
| 67 | ! Sizes are taken from this%left_operand%shape: last entry = num_channels, product of the others = num_elements | ||
| 68 | − | num_dims = size(this%left_operand%shape) | |
| 69 | − | num_elements = product(this%left_operand%shape(1:num_dims - 1)) | |
| 70 | − | num_channels = this%left_operand%shape(num_dims) | |
| 71 | ! Row i + (m-1)*num_elements of output receives upstream_grad(i,s) or zero, for every channel m and column s | ||
| 72 | − | do concurrent(s = 1:size(upstream_grad,2), m = 1: num_channels) | |
| 73 | − | do concurrent(i=1:num_elements) | |
| 74 | − | if(this%mask(i,1)) then | |
| 75 | − | output(i + (m-1) * num_elements,s) = upstream_grad(i,s) | |
| 76 | else | ||
| 77 | − | output(i + (m-1) * num_elements,s) = 0._real32 | |
| 78 | end if | ||
| 79 | end do | ||
| 80 | end do | ||
| 81 | ! NOTE(review): upstream_grad is read at row i for every m, not row i + (m-1)*num_elements - confirm intended | ||
| 82 | − | end subroutine get_partial_merge_scalar_over_channels_val | |
| 83 | !############################################################################### | ||
| 84 | |||
| 85 | end submodule athena__diffstruc_extd_submodule_merge | ||
| 86 |