| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_conv | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | |||
| 4 | contains | ||
| 5 | |||
| 6 | !############################################################################### | ||
| 7 | 6 | module function conv1d(input, kernel, stride, dilation) result(output) | |
| 8 | !! 1D convolution operation | ||
| 9 | implicit none | ||
| 10 | |||
| 11 | ! Arguments | ||
| 12 | type(array_type), intent(in), target :: input | ||
| 13 | type(array_type), intent(in), target :: kernel | ||
| 14 | integer, intent(in) :: stride | ||
| 15 | integer, intent(in) :: dilation | ||
| 16 | type(array_type), pointer :: output | ||
| 17 | |||
| 18 | ! Local variables | ||
| 19 | integer :: i, k, c_in, c_out, s | ||
| 20 | integer :: i_in, k_idx | ||
| 21 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 22 | real(real32) :: conv_sum | ||
| 23 | integer, dimension(3) :: output_shape | ||
| 24 | |||
| 25 | ! Extract dimensions | ||
| 26 | ! input: [H_in, C_in, B] | ||
| 27 | ! kernel: [K, C_in, C_out] | ||
| 28 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | input_h = input%shape(1) |
| 29 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | num_channels = input%shape(2) |
| 30 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | kernel_h = kernel%shape(1) |
| 31 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | num_filters = kernel%shape(3) |
| 32 | |||
| 33 | ! Calculate output dimensions | ||
| 34 | output_h = (input_h - dilation*(kernel_h - 1) - 1) / & | ||
| 35 | 6 | stride + 1 | |
| 36 |
2/2✓ Branch 0 taken 18 times.
✓ Branch 1 taken 6 times.
|
24 | output_shape = [output_h, num_filters, size(input%val, dim=2)] |
| 37 | |||
| 38 | 6 | output => input%create_result(array_shape = output_shape) | |
| 39 |
4/4✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 100 times.
✓ Branch 3 taken 6 times.
|
112 | output%val = 0._real32 |
| 40 | |||
| 41 | ! Perform convolution | ||
| 42 | do concurrent(s = 1:output_shape(3), c_out = 1:num_filters, & | ||
| 43 | 6 | i = 1:output_h) | |
| 44 | 100 | conv_sum = 0._real32 | |
| 45 |
2/2✓ Branch 0 taken 340 times.
✓ Branch 1 taken 100 times.
|
440 | do c_in = 1, num_channels |
| 46 |
2/2✓ Branch 0 taken 1209 times.
✓ Branch 1 taken 340 times.
|
1649 | do k = 1, kernel_h |
| 47 | 1209 | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 48 |
1/2✓ Branch 0 taken 1209 times.
✗ Branch 1 not taken.
|
1549 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 49 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 50 | 1209 | ( c_out - 1 ) * kernel_h * num_channels | |
| 51 | conv_sum = conv_sum + & | ||
| 52 | ✗ | input%val(i_in + ( c_in - 1 ) * input_h, s) * & | |
| 53 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 1209 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1209 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1209 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1209 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1209 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1209 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1209 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1209 times.
|
1209 | kernel%val(k_idx, 1) |
| 54 | end if | ||
| 55 | end do | ||
| 56 | end do | ||
| 57 |
10/14✓ Branch 0 taken 22 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 100 times.
✓ Branch 3 taken 22 times.
✓ Branch 4 taken 100 times.
✓ Branch 5 taken 100 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 100 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 100 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 100 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 100 times.
|
228 | output%val(i + (c_out-1)*output_h, s) = conv_sum |
| 58 | end do | ||
| 59 | |||
| 60 | ! Store parameters for backward pass | ||
| 61 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
|
6 | allocate(output%indices(2)) |
| 62 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | output%indices(1) = num_channels |
| 63 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
|
6 | output%indices(2) = num_filters |
| 64 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
|
6 | allocate(output%adj_ja(1,3)) |
| 65 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
|
6 | output%adj_ja(1,1) = stride |
| 66 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
|
6 | output%adj_ja(1,2) = dilation |
| 67 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
|
6 | output%adj_ja(1,3) = kernel_h |
| 68 | |||
| 69 | 6 | output%get_partial_left => get_partial_conv1d_input | |
| 70 | 6 | output%get_partial_right => get_partial_conv1d_kernel | |
| 71 | 6 | output%get_partial_left_val => get_partial_conv1d_input_val | |
| 72 | 6 | output%get_partial_right_val => get_partial_conv1d_kernel_val | |
| 73 |
1/2✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
|
6 | if(input%requires_grad .or. kernel%requires_grad)then |
| 74 | 6 | output%requires_grad = .true. | |
| 75 | 6 | output%is_forward = input%is_forward | |
| 76 | 6 | output%operation = 'conv1d' | |
| 77 | 6 | output%left_operand => input | |
| 78 | 6 | output%right_operand => kernel | |
| 79 | end if | ||
| 80 | |||
| 81 | 6 | end function conv1d | |
| 82 | !------------------------------------------------------------------------------- | ||
| 83 | ✗ | function get_partial_conv1d_input(this, upstream_grad) result(output) | |
| 84 | !! Get partial derivative wrt input for 1D convolution | ||
| 85 | implicit none | ||
| 86 | |||
| 87 | class(array_type), intent(inout) :: this | ||
| 88 | type(array_type), intent(in) :: upstream_grad | ||
| 89 | type(array_type) :: output | ||
| 90 | |||
| 91 | ✗ | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 92 | ✗ | size(this%left_operand%val, dim=2) ]) | |
| 93 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 94 | |||
| 95 | ✗ | end function get_partial_conv1d_input | |
| 96 | !------------------------------------------------------------------------------- | ||
| 97 | ✗ | function get_partial_conv1d_kernel(this, upstream_grad) result(output) | |
| 98 | !! Get partial derivative wrt kernel for 1D convolution | ||
| 99 | implicit none | ||
| 100 | |||
| 101 | class(array_type), intent(inout) :: this | ||
| 102 | type(array_type), intent(in) :: upstream_grad | ||
| 103 | type(array_type) :: output | ||
| 104 | |||
| 105 | ✗ | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 106 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 107 | |||
| 108 | ✗ | end function get_partial_conv1d_kernel | |
| 109 | !------------------------------------------------------------------------------- | ||
| 110 |
2/4✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
|
2 | pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output) |
| 111 | !! Get partial derivative wrt input for 1D convolution (subroutine version) | ||
| 112 | implicit none | ||
| 113 | |||
| 114 | class(array_type), intent(in) :: this | ||
| 115 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 116 | real(real32), dimension(:,:), intent(out) :: output | ||
| 117 | |||
| 118 | ! Local variables | ||
| 119 | integer :: i, k, c_in, c_out, s | ||
| 120 | integer :: i_in, k_idx, out_idx | ||
| 121 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 122 | integer :: stride, dilation | ||
| 123 | real(real32) :: grad_val, kernel_val | ||
| 124 | |||
| 125 | |||
| 126 | ! Unpack parameters | ||
| 127 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | num_channels = this%indices(1) |
| 128 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | num_filters = this%indices(2) |
| 129 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | stride = this%adj_ja(1,1) |
| 130 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | dilation = this%adj_ja(1,2) |
| 131 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | kernel_h = this%adj_ja(1,3) |
| 132 | |||
| 133 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | input_h = this%left_operand%shape(1) |
| 134 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | output_h = this%shape(1) |
| 135 | |||
| 136 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 2 times.
✓ Branch 20 taken 28 times.
✓ Branch 21 taken 2 times.
|
32 | output = 0._real32 |
| 137 | |||
| 138 | ! Parallelised over batch, channels, and output positions | ||
| 139 | ✗ | do concurrent(s = 1:size(upstream_grad, dim=2), c_in = 1:num_channels, & | |
| 140 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
|
2 | i = 1:output_h, c_out = 1:num_filters) |
| 141 | 99 | out_idx = i + (c_out-1)*output_h | |
| 142 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 99 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 99 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 99 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 99 times.
|
99 | grad_val = upstream_grad(out_idx, s) |
| 143 | |||
| 144 |
9/10✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 27 times.
✓ Branch 3 taken 9 times.
✓ Branch 4 taken 99 times.
✓ Branch 5 taken 27 times.
✓ Branch 6 taken 99 times.
✓ Branch 7 taken 99 times.
✓ Branch 8 taken 99 times.
✗ Branch 9 not taken.
|
335 | if(abs(grad_val) .gt. 1.e-30_real32)then |
| 145 |
2/2✓ Branch 0 taken 390 times.
✓ Branch 1 taken 99 times.
|
489 | do k = 1, kernel_h |
| 146 | 390 | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 147 |
1/2✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
|
489 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 148 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 149 | 390 | ( c_out - 1 ) * kernel_h * num_channels | |
| 150 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 390 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 390 times.
|
390 | kernel_val = this%right_operand%val(k_idx, 1) |
| 151 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
|
780 | output(i_in + ( c_in - 1 ) * input_h, s) = & |
| 152 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
|
780 | output(i_in + ( c_in - 1 ) * input_h, s) + & |
| 153 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 390 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 390 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 390 times.
|
1950 | grad_val * kernel_val |
| 154 | end if | ||
| 155 | end do | ||
| 156 | end if | ||
| 157 | end do | ||
| 158 | |||
| 159 | 2 | end subroutine get_partial_conv1d_input_val | |
| 160 | !------------------------------------------------------------------------------- | ||
| 161 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | pure subroutine get_partial_conv1d_kernel_val(this, upstream_grad, output) |
| 162 | !! Get partial derivative wrt kernel for 1D convolution (subroutine version) | ||
| 163 | implicit none | ||
| 164 | |||
| 165 | class(array_type), intent(in) :: this | ||
| 166 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 167 | real(real32), dimension(:,:), intent(out) :: output | ||
| 168 | |||
| 169 | ! Local variables | ||
| 170 | integer :: i, k, c_in, c_out, s | ||
| 171 | integer :: i_in, k_idx, out_idx | ||
| 172 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 173 | integer :: stride, dilation | ||
| 174 | real(real32) :: grad_sum | ||
| 175 | |||
| 176 | |||
| 177 | ! Unpack parameters | ||
| 178 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
|
3 | num_channels = this%indices(1) |
| 179 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
|
3 | num_filters = this%indices(2) |
| 180 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
|
3 | stride = this%adj_ja(1,1) |
| 181 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
|
3 | dilation = this%adj_ja(1,2) |
| 182 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
|
3 | kernel_h = this%adj_ja(1,3) |
| 183 | |||
| 184 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
|
3 | input_h = this%left_operand%shape(1) |
| 185 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
|
3 | output_h = this%shape(1) |
| 186 | |||
| 187 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 18 taken 3 times.
✓ Branch 19 taken 3 times.
✓ Branch 20 taken 166 times.
✓ Branch 21 taken 3 times.
|
172 | output = 0._real32 |
| 188 | |||
| 189 | ! Parallelised over filters, channels, and kernel positions | ||
| 190 | 3 | do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, k = 1:kernel_h) | |
| 191 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 192 | 166 | ( c_out - 1 ) * kernel_h * num_channels | |
| 193 | |||
| 194 | 166 | grad_sum = 0._real32 | |
| 195 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 166 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 166 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 166 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 166 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 166 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 166 times.
✓ Branch 18 taken 166 times.
✓ Branch 19 taken 166 times.
|
332 | do s = 1, size(upstream_grad, dim=2) |
| 196 |
2/2✓ Branch 0 taken 606 times.
✓ Branch 1 taken 166 times.
|
938 | do i = 1, output_h |
| 197 | 606 | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 198 |
1/2✓ Branch 0 taken 606 times.
✗ Branch 1 not taken.
|
772 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 199 | 606 | out_idx = i + ( c_out - 1 ) * output_h | |
| 200 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 606 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 606 times.
|
1212 | grad_sum = grad_sum + upstream_grad(out_idx, s) * & |
| 201 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 606 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 606 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 606 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 606 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 606 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 606 times.
|
1818 | this%left_operand%val(i_in + ( c_in - 1 ) * input_h, s) |
| 202 | end if | ||
| 203 | end do | ||
| 204 | end do | ||
| 205 |
9/12✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 27 times.
✓ Branch 3 taken 9 times.
✓ Branch 4 taken 166 times.
✓ Branch 5 taken 27 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 166 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 166 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 166 times.
|
205 | output(k_idx, 1) = grad_sum |
| 206 | end do | ||
| 207 | |||
| 208 | 3 | end subroutine get_partial_conv1d_kernel_val | |
| 209 | !############################################################################### | ||
| 210 | |||
| 211 | |||
| 212 | !############################################################################### | ||
| 213 | 18 | module function conv2d(input, kernel, stride, dilation) result(output) | |
| 214 | !! 2D convolution operation | ||
| 215 | implicit none | ||
| 216 | |||
| 217 | ! Arguments | ||
| 218 | type(array_type), intent(in), target :: input | ||
| 219 | type(array_type), intent(in), target :: kernel | ||
| 220 | integer, dimension(2), intent(in) :: stride | ||
| 221 | integer, dimension(2), intent(in) :: dilation | ||
| 222 | type(array_type), pointer :: output | ||
| 223 | |||
| 224 | ! Local variables | ||
| 225 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 226 | integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx | ||
| 227 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 228 | integer :: output_h, output_w, num_channels, num_filters | ||
| 229 | integer :: channel_size_in, channel_size_out, kernel_channel_size | ||
| 230 | integer :: dil_kernel_h_m1, dil_kernel_w_m1 | ||
| 231 | real(real32) :: conv_sum | ||
| 232 | integer, dimension(4) :: output_shape | ||
| 233 | |||
| 234 | ! Extract dimensions | ||
| 235 | ! input: [H_in, W_in, C_in, B] | ||
| 236 | ! kernel: [K_h, K_w, C_in, C_out] | ||
| 237 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | input_h = input%shape(1) |
| 238 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | input_w = input%shape(2) |
| 239 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | num_channels = input%shape(3) |
| 240 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | kernel_h = kernel%shape(1) |
| 241 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | kernel_w = kernel%shape(2) |
| 242 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | num_filters = kernel%shape(4) |
| 243 | |||
| 244 | ! Pre-compute common values | ||
| 245 | 18 | channel_size_in = input_h * input_w | |
| 246 | 18 | kernel_channel_size = kernel_h * kernel_w | |
| 247 | 18 | dil_kernel_h_m1 = dilation(1) * (kernel_h - 1) | |
| 248 | 18 | dil_kernel_w_m1 = dilation(2) * (kernel_w - 1) | |
| 249 | |||
| 250 | ! Calculate output dimensions | ||
| 251 | 18 | output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1 | |
| 252 | 18 | output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1 | |
| 253 | output_shape = [output_h, output_w, num_filters, & | ||
| 254 |
2/2✓ Branch 0 taken 72 times.
✓ Branch 1 taken 18 times.
|
90 | size(input%val, dim=2)] |
| 255 | |||
| 256 | 18 | output => input%create_result(array_shape = output_shape) | |
| 257 |
4/4✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
✓ Branch 2 taken 732 times.
✓ Branch 3 taken 18 times.
|
768 | output%val = 0._real32 |
| 258 | |||
| 259 | 18 | channel_size_out = output_h * output_w | |
| 260 | |||
| 261 | ! Perform convolution | ||
| 262 | do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, & | ||
| 263 | 18 | j = 1:output_w, i = 1:output_h) | |
| 264 | 732 | conv_sum = 0._real32 | |
| 265 |
2/2✓ Branch 0 taken 1740 times.
✓ Branch 1 taken 732 times.
|
2472 | do c_in = 1, num_channels |
| 266 | 1740 | in_base_idx = (c_in - 1) * channel_size_in | |
| 267 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 268 | 1740 | (c_out - 1) * kernel_channel_size * num_channels | |
| 269 |
2/2✓ Branch 0 taken 5775 times.
✓ Branch 1 taken 1740 times.
|
8247 | do kj = 1, kernel_w |
| 270 | 5775 | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 271 |
1/2✓ Branch 0 taken 5775 times.
✗ Branch 1 not taken.
|
7515 | if(j_in .ge. 1 .and. j_in .le. input_w)then |
| 272 |
2/2✓ Branch 0 taken 19607 times.
✓ Branch 1 taken 5775 times.
|
25382 | do ki = 1, kernel_h |
| 273 | 19607 | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 274 |
1/2✓ Branch 0 taken 19607 times.
✗ Branch 1 not taken.
|
25382 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 275 | 19607 | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 276 | 19607 | k_idx = ki + (kj - 1) * kernel_h + k_base_idx | |
| 277 | ✗ | conv_sum = conv_sum + input%val(in_idx, s) * & | |
| 278 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 19607 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 19607 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 19607 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 19607 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 19607 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 19607 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 19607 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 19607 times.
|
19607 | kernel%val(k_idx, 1) |
| 279 | end if | ||
| 280 | end do | ||
| 281 | end if | ||
| 282 | end do | ||
| 283 | end do | ||
| 284 | 732 | out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out | |
| 285 |
12/16✓ Branch 0 taken 42 times.
✓ Branch 1 taken 18 times.
✓ Branch 2 taken 174 times.
✓ Branch 3 taken 42 times.
✓ Branch 4 taken 732 times.
✓ Branch 5 taken 174 times.
✓ Branch 6 taken 732 times.
✓ Branch 7 taken 732 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 732 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 732 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 732 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 732 times.
|
1698 | output%val(out_idx, s) = conv_sum |
| 286 | end do | ||
| 287 | |||
| 288 | ! Store parameters for backward pass | ||
| 289 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
|
18 | allocate(output%indices(2)) |
| 290 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | output%indices(1) = num_channels |
| 291 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
|
18 | output%indices(2) = num_filters |
| 292 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
|
18 | allocate(output%adj_ja(2,3)) |
| 293 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✓ Branch 18 taken 36 times.
✓ Branch 19 taken 18 times.
|
54 | output%adj_ja(1:2,1) = stride |
| 294 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✓ Branch 18 taken 36 times.
✓ Branch 19 taken 18 times.
|
54 | output%adj_ja(1:2,2) = dilation |
| 295 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
|
18 | output%adj_ja(1,3) = kernel_h |
| 296 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
|
18 | output%adj_ja(2,3) = kernel_w |
| 297 | |||
| 298 | |||
| 299 | 18 | output%get_partial_left => get_partial_conv2d_input | |
| 300 | 18 | output%get_partial_right => get_partial_conv2d_kernel | |
| 301 | 18 | output%get_partial_left_val => get_partial_conv2d_input_val | |
| 302 | 18 | output%get_partial_right_val => get_partial_conv2d_kernel_val | |
| 303 |
1/2✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
|
18 | if(input%requires_grad .or. kernel%requires_grad)then |
| 304 | 18 | output%requires_grad = .true. | |
| 305 | 18 | output%is_forward = input%is_forward | |
| 306 | 18 | output%operation = 'conv2d' | |
| 307 | 18 | output%left_operand => input | |
| 308 | 18 | output%right_operand => kernel | |
| 309 | end if | ||
| 310 | |||
| 311 | 18 | end function conv2d | |
| 312 | !------------------------------------------------------------------------------- | ||
| 313 | ✗ | function get_partial_conv2d_input(this, upstream_grad) result(output) | |
| 314 | !! Get partial derivative wrt input for 2D convolution | ||
| 315 | implicit none | ||
| 316 | |||
| 317 | class(array_type), intent(inout) :: this | ||
| 318 | type(array_type), intent(in) :: upstream_grad | ||
| 319 | type(array_type) :: output | ||
| 320 | |||
| 321 | ✗ | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 322 | ✗ | size(this%left_operand%val, dim=2) ]) | |
| 323 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 324 | |||
| 325 | ✗ | end function get_partial_conv2d_input | |
| 326 | !------------------------------------------------------------------------------- | ||
| 327 | ✗ | function get_partial_conv2d_kernel(this, upstream_grad) result(output) | |
| 328 | !! Get partial derivative wrt kernel for 2D convolution | ||
| 329 | implicit none | ||
| 330 | |||
| 331 | class(array_type), intent(inout) :: this | ||
| 332 | type(array_type), intent(in) :: upstream_grad | ||
| 333 | type(array_type) :: output | ||
| 334 | |||
| 335 | ✗ | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 336 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 337 | |||
| 338 | ✗ | end function get_partial_conv2d_kernel | |
| 339 | !------------------------------------------------------------------------------- | ||
| 340 |
2/4✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13 times.
✗ Branch 3 not taken.
|
13 | pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output) |
| 341 | !! Get partial derivative wrt input for 2D convolution (subroutine version) | ||
| 342 | implicit none | ||
| 343 | |||
| 344 | class(array_type), intent(in) :: this | ||
| 345 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 346 | real(real32), dimension(:,:), intent(out) :: output | ||
| 347 | |||
| 348 | ! Local variables | ||
| 349 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 350 | integer :: i_in, j_in, k_idx, out_idx, in_idx | ||
| 351 | integer :: in_base_idx, k_base_idx, kernel_channel_size | ||
| 352 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 353 | integer :: output_h, output_w, num_channels, num_filters | ||
| 354 | integer, dimension(2) :: stride, dilation | ||
| 355 | integer :: channel_size_in, channel_size_out | ||
| 356 | real(real32) :: grad_val, kernel_val | ||
| 357 | |||
| 358 | |||
| 359 | ! Unpack parameters | ||
| 360 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | num_channels = this%indices(1) |
| 361 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | num_filters = this%indices(2) |
| 362 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 26 times.
✓ Branch 19 taken 13 times.
|
39 | stride = this%adj_ja(1:2,1) |
| 363 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 26 times.
✓ Branch 19 taken 13 times.
|
39 | dilation = this%adj_ja(1:2,2) |
| 364 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
|
13 | kernel_h = this%adj_ja(1,3) |
| 365 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
|
13 | kernel_w = this%adj_ja(2,3) |
| 366 | |||
| 367 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | input_h = this%left_operand%shape(1) |
| 368 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | input_w = this%left_operand%shape(2) |
| 369 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | output_h = this%shape(1) |
| 370 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | output_w = this%shape(2) |
| 371 | 13 | channel_size_in = input_h * input_w | |
| 372 | 13 | channel_size_out = output_h * output_w | |
| 373 | 13 | kernel_channel_size = kernel_h * kernel_w | |
| 374 | |||
| 375 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
✓ Branch 20 taken 159 times.
✓ Branch 21 taken 13 times.
|
185 | output = 0._real32 |
| 376 | |||
| 377 | ! Parallelised over batch, output channels, output spatial dims | ||
| 378 | ✗ | do concurrent(s = 1:size(upstream_grad, dim=2), c_out = 1:num_filters, & | |
| 379 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 13 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 13 times.
|
13 | j = 1:output_w, i = 1:output_h) |
| 380 | 84 | out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out | |
| 381 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 84 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 84 times.
|
84 | grad_val = upstream_grad(out_idx, s) |
| 382 | |||
| 383 |
10/10✓ Branch 0 taken 15 times.
✓ Branch 1 taken 13 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 15 times.
✓ Branch 4 taken 84 times.
✓ Branch 5 taken 21 times.
✓ Branch 6 taken 84 times.
✓ Branch 7 taken 84 times.
✓ Branch 8 taken 82 times.
✓ Branch 9 taken 2 times.
|
299 | if(abs(grad_val) .gt. 1.e-30_real32)then |
| 384 |
2/2✓ Branch 0 taken 298 times.
✓ Branch 1 taken 82 times.
|
380 | do c_in = 1, num_channels |
| 385 | 298 | in_base_idx = (c_in - 1) * channel_size_in | |
| 386 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 387 | 298 | (c_out - 1) * kernel_channel_size * num_channels | |
| 388 | |||
| 389 |
2/2✓ Branch 0 taken 1163 times.
✓ Branch 1 taken 298 times.
|
1543 | do kj = 1, kernel_w |
| 390 | 1163 | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 391 |
1/2✓ Branch 0 taken 1163 times.
✗ Branch 1 not taken.
|
1461 | if(j_in .ge. 1 .and. j_in .le. input_w)then |
| 392 |
2/2✓ Branch 0 taken 4621 times.
✓ Branch 1 taken 1163 times.
|
5784 | do ki = 1, kernel_h |
| 393 | 4621 | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 394 |
1/2✓ Branch 0 taken 4621 times.
✗ Branch 1 not taken.
|
5784 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 395 | 4621 | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 396 | k_idx = (kernel_h - ki + 1) + & | ||
| 397 | 4621 | (kernel_w - kj) * kernel_h + k_base_idx | |
| 398 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 4621 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4621 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4621 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4621 times.
|
4621 | kernel_val = this%right_operand%val(k_idx, 1) |
| 399 |
4/8✗ Branch 1 not taken.
✓ Branch 2 taken 4621 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4621 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4621 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4621 times.
|
18484 | output(in_idx, s) = output(in_idx, s) + & |
| 400 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 4621 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4621 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4621 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4621 times.
|
23105 | grad_val * kernel_val |
| 401 | end if | ||
| 402 | end do | ||
| 403 | end if | ||
| 404 | end do | ||
| 405 | end do | ||
| 406 | end if | ||
| 407 | end do | ||
| 408 | |||
| 409 | 13 | end subroutine get_partial_conv2d_input_val | |
| 410 | !------------------------------------------------------------------------------- | ||
| 411 |
2/4✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
|
14 | pure subroutine get_partial_conv2d_kernel_val(this, upstream_grad, output) |
| 412 | !! Get partial derivative wrt kernel for 2D convolution (subroutine version) | ||
| 413 | implicit none | ||
| 414 | |||
| 415 | class(array_type), intent(in) :: this | ||
| 416 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 417 | real(real32), dimension(:,:), intent(out) :: output | ||
| 418 | |||
| 419 | ! Local variables | ||
| 420 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 421 | integer :: i_in, j_in, k_idx, out_idx, in_idx | ||
| 422 | integer :: in_base_idx, out_base_idx, k_base_idx, kernel_channel_size | ||
| 423 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 424 | integer :: output_h, output_w, num_channels, num_filters | ||
| 425 | integer, dimension(2) :: stride, dilation | ||
| 426 | integer :: channel_size_in, channel_size_out | ||
| 427 | real(real32) :: grad_sum | ||
| 428 | |||
| 429 | |||
| 430 | ! Unpack parameters | ||
| 431 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | num_channels = this%indices(1) |
| 432 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | num_filters = this%indices(2) |
| 433 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 28 times.
✓ Branch 19 taken 14 times.
|
42 | stride = this%adj_ja(1:2,1) |
| 434 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 28 times.
✓ Branch 19 taken 14 times.
|
42 | dilation = this%adj_ja(1:2,2) |
| 435 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
|
14 | kernel_h = this%adj_ja(1,3) |
| 436 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
|
14 | kernel_w = this%adj_ja(2,3) |
| 437 | |||
| 438 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | input_h = this%left_operand%shape(1) |
| 439 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | input_w = this%left_operand%shape(2) |
| 440 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | output_h = this%shape(1) |
| 441 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | output_w = this%shape(2) |
| 442 | 14 | channel_size_in = input_h * input_w | |
| 443 | 14 | channel_size_out = output_h * output_w | |
| 444 | 14 | kernel_channel_size = kernel_h * kernel_w | |
| 445 | |||
| 446 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 14 times.
✓ Branch 20 taken 635 times.
✓ Branch 21 taken 14 times.
|
663 | output = 0._real32 |
| 447 | |||
| 448 | ! Parallelised over filters, channels, and kernel dimensions | ||
| 449 | do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, & | ||
| 450 | 14 | kj = 1:kernel_w, ki = 1:kernel_h) | |
| 451 | 635 | out_base_idx = (c_out - 1) * channel_size_out | |
| 452 | 635 | in_base_idx = (c_in - 1) * channel_size_in | |
| 453 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 454 | 635 | (c_out - 1) * kernel_channel_size * num_channels | |
| 455 | 635 | k_idx = ki + (kj - 1) * kernel_h + k_base_idx | |
| 456 | |||
| 457 | 635 | grad_sum = 0._real32 | |
| 458 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 635 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 635 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 635 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 635 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 635 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 635 times.
✓ Branch 18 taken 635 times.
✓ Branch 19 taken 635 times.
|
1270 | do s = 1, size(upstream_grad, dim=2) |
| 459 |
2/2✓ Branch 0 taken 2199 times.
✓ Branch 1 taken 635 times.
|
3469 | do j = 1, output_w |
| 460 | 2199 | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 461 |
1/2✓ Branch 0 taken 2199 times.
✗ Branch 1 not taken.
|
2834 | if(j_in .ge. 1 .and. j_in .le. input_w)then |
| 462 |
2/2✓ Branch 0 taken 8511 times.
✓ Branch 1 taken 2199 times.
|
10710 | do i = 1, output_h |
| 463 | 8511 | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 464 |
1/2✓ Branch 0 taken 8511 times.
✗ Branch 1 not taken.
|
10710 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 465 | 8511 | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 466 | 8511 | out_idx = i + (j - 1) * output_h + out_base_idx | |
| 467 | grad_sum = grad_sum + & | ||
| 468 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 8511 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8511 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8511 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8511 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8511 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8511 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8511 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8511 times.
|
8511 | upstream_grad(out_idx, s) * this%left_operand%val(in_idx, s) |
| 469 | end if | ||
| 470 | end do | ||
| 471 | end if | ||
| 472 | end do | ||
| 473 | end do | ||
| 474 |
11/14✓ Branch 0 taken 20 times.
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 40 times.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 106 times.
✓ Branch 5 taken 40 times.
✓ Branch 6 taken 635 times.
✓ Branch 7 taken 106 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 635 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 635 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 635 times.
|
815 | output(k_idx, 1) = grad_sum |
| 475 | end do | ||
| 476 | |||
| 477 | 14 | end subroutine get_partial_conv2d_kernel_val | |
| 478 | !############################################################################### | ||
| 479 | |||
| 480 | |||
| 481 | !############################################################################### | ||
| 482 | 16 | module function conv3d(input, kernel, stride, dilation) result(output) | |
| 483 | !! 3D convolution operation | ||
| 484 | implicit none | ||
| 485 | |||
| 486 | ! Arguments | ||
| 487 | type(array_type), intent(in), target :: input | ||
| 488 | type(array_type), intent(in), target :: kernel | ||
| 489 | integer, dimension(3), intent(in) :: stride | ||
| 490 | integer, dimension(3), intent(in) :: dilation | ||
| 491 | type(array_type), pointer :: output | ||
| 492 | |||
| 493 | ! Local variables | ||
| 494 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 495 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 496 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 497 | integer :: output_h, output_w, output_d | ||
| 498 | integer :: num_channels, num_filters | ||
| 499 | integer :: channel_size_in, channel_size_out | ||
| 500 | real(real32) :: conv_sum | ||
| 501 | integer, dimension(5) :: output_shape | ||
| 502 | |||
| 503 | ! Extract dimensions | ||
| 504 | ! input: [H_in, W_in, D_in, C_in, B] | ||
| 505 | ! kernel: [K_h, K_w, K_d, C_in, C_out] | ||
| 506 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | input_h = input%shape(1) |
| 507 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | input_w = input%shape(2) |
| 508 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | input_d = input%shape(3) |
| 509 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | num_channels = input%shape(4) |
| 510 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | kernel_h = kernel%shape(1) |
| 511 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | kernel_w = kernel%shape(2) |
| 512 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | kernel_d = kernel%shape(3) |
| 513 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | num_filters = kernel%shape(5) |
| 514 | |||
| 515 | ! Calculate output dimensions | ||
| 516 | output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / & | ||
| 517 | 16 | stride(1) + 1 | |
| 518 | output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / & | ||
| 519 | 16 | stride(2) + 1 | |
| 520 | output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / & | ||
| 521 | 16 | stride(3) + 1 | |
| 522 | output_shape = [output_h, output_w, output_d, num_filters, & | ||
| 523 |
2/2✓ Branch 0 taken 80 times.
✓ Branch 1 taken 16 times.
|
96 | size(input%val, dim=2)] |
| 524 | |||
| 525 | 16 | output => input%create_result(array_shape = output_shape) | |
| 526 |
4/4✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 2172 times.
✓ Branch 3 taken 16 times.
|
2204 | output%val = 0._real32 |
| 527 | |||
| 528 | 16 | channel_size_in = input_h * input_w * input_d | |
| 529 | 16 | channel_size_out = output_h * output_w * output_d | |
| 530 | |||
| 531 | ! Perform convolution - optimised with do concurrent | ||
| 532 | do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, & | ||
| 533 | 16 | k = 1:output_d, j = 1:output_w, i = 1:output_h) | |
| 534 | |||
| 535 | 2172 | conv_sum = 0._real32 | |
| 536 |
2/2✓ Branch 0 taken 6924 times.
✓ Branch 1 taken 2172 times.
|
9096 | do c_in = 1, num_channels |
| 537 |
2/2✓ Branch 0 taken 22479 times.
✓ Branch 1 taken 6924 times.
|
31575 | do kk = 1, kernel_d |
| 538 | 22479 | k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1 | |
| 539 |
1/2✓ Branch 0 taken 22479 times.
✗ Branch 1 not taken.
|
29403 | if(k_in .ge. 1 .and. k_in .le. input_d)then |
| 540 |
2/2✓ Branch 0 taken 74327 times.
✓ Branch 1 taken 22479 times.
|
96806 | do kj = 1, kernel_w |
| 541 | 74327 | j_in = ( j - 1 ) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 542 |
1/2✓ Branch 0 taken 74327 times.
✗ Branch 1 not taken.
|
96806 | if(j_in .ge. 1 .and. j_in .le. input_w)then |
| 543 |
2/2✓ Branch 0 taken 250605 times.
✓ Branch 1 taken 74327 times.
|
324932 | do ki = 1, kernel_h |
| 544 | i_in = ( i - 1 ) * stride(1) + & | ||
| 545 | 250605 | ( ki - 1 ) * dilation(1) + 1 | |
| 546 |
1/2✓ Branch 0 taken 250605 times.
✗ Branch 1 not taken.
|
324932 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 547 | in_idx = i_in + ( j_in - 1 ) * input_h + & | ||
| 548 | ( k_in - 1 ) * input_h * input_w + & | ||
| 549 | 250605 | ( c_in - 1 ) * channel_size_in | |
| 550 | k_idx = ki + ( kj - 1 ) * kernel_h + & | ||
| 551 | ( kk - 1 ) * kernel_h * kernel_w + & | ||
| 552 | ( c_in - 1 ) * kernel_h * kernel_w * & | ||
| 553 | kernel_d + & | ||
| 554 | ( c_out - 1 ) * kernel_h * kernel_w * & | ||
| 555 | 250605 | kernel_d * num_channels | |
| 556 | ✗ | conv_sum = conv_sum + input%val(in_idx, s) * & | |
| 557 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 250605 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 250605 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 250605 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 250605 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 250605 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 250605 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 250605 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 250605 times.
|
250605 | kernel%val(k_idx, 1) |
| 558 | end if | ||
| 559 | end do | ||
| 560 | end if | ||
| 561 | end do | ||
| 562 | end if | ||
| 563 | end do | ||
| 564 | end do | ||
| 565 | out_idx = i + ( j - 1 ) * output_h + & | ||
| 566 | ( k - 1 ) * output_h * output_w + & | ||
| 567 | 2172 | ( c_out - 1 ) * channel_size_out | |
| 568 |
14/18✓ Branch 0 taken 30 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 102 times.
✓ Branch 3 taken 30 times.
✓ Branch 4 taken 498 times.
✓ Branch 5 taken 102 times.
✓ Branch 6 taken 2172 times.
✓ Branch 7 taken 498 times.
✓ Branch 8 taken 2172 times.
✓ Branch 9 taken 2172 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2172 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2172 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2172 times.
|
4990 | output%val(out_idx, s) = conv_sum |
| 569 | end do | ||
| 570 | |||
| 571 | ! Store parameters for backward pass | ||
| 572 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
|
16 | allocate(output%indices(2)) |
| 573 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | output%indices(1) = num_channels |
| 574 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
|
16 | output%indices(2) = num_filters |
| 575 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
|
16 | allocate(output%adj_ja(3,3)) |
| 576 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✓ Branch 18 taken 48 times.
✓ Branch 19 taken 16 times.
|
64 | output%adj_ja(1:3,1) = stride |
| 577 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✓ Branch 18 taken 48 times.
✓ Branch 19 taken 16 times.
|
64 | output%adj_ja(1:3,2) = dilation |
| 578 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
|
16 | output%adj_ja(1,3) = kernel_h |
| 579 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
|
16 | output%adj_ja(2,3) = kernel_w |
| 580 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
|
16 | output%adj_ja(3,3) = kernel_d |
| 581 | |||
| 582 | 16 | output%get_partial_left => get_partial_conv3d_input | |
| 583 | 16 | output%get_partial_right => get_partial_conv3d_kernel | |
| 584 | 16 | output%get_partial_left_val => get_partial_conv3d_input_val | |
| 585 | 16 | output%get_partial_right_val => get_partial_conv3d_kernel_val | |
| 586 |
1/2✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
|
16 | if(input%requires_grad .or. kernel%requires_grad)then |
| 587 | 16 | output%requires_grad = .true. | |
| 588 | 16 | output%is_forward = input%is_forward | |
| 589 | 16 | output%operation = 'conv3d' | |
| 590 | 16 | output%left_operand => input | |
| 591 | 16 | output%right_operand => kernel | |
| 592 | end if | ||
| 593 | |||
| 594 | 16 | end function conv3d | |
| 595 | !------------------------------------------------------------------------------- | ||
| 596 | ✗ | function get_partial_conv3d_input(this, upstream_grad) result(output) | |
| 597 | !! Get partial derivative wrt input for 3D convolution | ||
| 598 | implicit none | ||
| 599 | |||
| 600 | class(array_type), intent(inout) :: this | ||
| 601 | type(array_type), intent(in) :: upstream_grad | ||
| 602 | type(array_type) :: output | ||
| 603 | |||
| 604 | ✗ | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 605 | ✗ | size(this%left_operand%val, dim=2) ]) | |
| 606 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 607 | |||
| 608 | ✗ | end function get_partial_conv3d_input | |
| 609 | !------------------------------------------------------------------------------- | ||
| 610 | ✗ | function get_partial_conv3d_kernel(this, upstream_grad) result(output) | |
| 611 | !! Get partial derivative wrt kernel for 3D convolution | ||
| 612 | implicit none | ||
| 613 | |||
| 614 | class(array_type), intent(inout) :: this | ||
| 615 | type(array_type), intent(in) :: upstream_grad | ||
| 616 | type(array_type) :: output | ||
| 617 | |||
| 618 | ✗ | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 619 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 620 | |||
| 621 | ✗ | end function get_partial_conv3d_kernel | |
| 622 | !------------------------------------------------------------------------------- | ||
| 623 |
2/4✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13 times.
✗ Branch 3 not taken.
|
13 | pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output) |
| 624 | !! Get partial derivative wrt input for 3D convolution (subroutine version) | ||
| 625 | implicit none | ||
| 626 | |||
| 627 | class(array_type), intent(in) :: this | ||
| 628 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 629 | real(real32), dimension(:,:), intent(out) :: output | ||
| 630 | |||
| 631 | ! Local variables | ||
| 632 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 633 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 634 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 635 | integer :: output_h, output_w, output_d | ||
| 636 | integer :: num_channels, num_filters | ||
| 637 | integer :: channel_size_in, channel_size_out | ||
| 638 | integer, dimension(3) :: stride, dilation | ||
| 639 | real(real32) :: grad_val, kernel_val | ||
| 640 | |||
| 641 | |||
| 642 | ! Unpack parameters | ||
| 643 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | num_channels = this%indices(1) |
| 644 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | num_filters = this%indices(2) |
| 645 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 39 times.
✓ Branch 19 taken 13 times.
|
52 | stride = this%adj_ja(1:3,1) |
| 646 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 39 times.
✓ Branch 19 taken 13 times.
|
52 | dilation = this%adj_ja(1:3,2) |
| 647 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
|
13 | kernel_h = this%adj_ja(1,3) |
| 648 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
|
13 | kernel_w = this%adj_ja(2,3) |
| 649 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
|
13 | kernel_d = this%adj_ja(3,3) |
| 650 | |||
| 651 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | input_h = this%left_operand%shape(1) |
| 652 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | input_w = this%left_operand%shape(2) |
| 653 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | input_d = this%left_operand%shape(3) |
| 654 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | output_h = this%shape(1) |
| 655 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | output_w = this%shape(2) |
| 656 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | output_d = this%shape(3) |
| 657 | |||
| 658 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
✓ Branch 20 taken 883 times.
✓ Branch 21 taken 13 times.
|
909 | output = 0._real32 |
| 659 | |||
| 660 | 13 | channel_size_in = input_h * input_w * input_d | |
| 661 | 13 | channel_size_out = output_h * output_w * output_d | |
| 662 | |||
| 663 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
|
26 | do s = 1, size(upstream_grad, dim=2) |
| 664 |
2/2✓ Branch 0 taken 16 times.
✓ Branch 1 taken 13 times.
|
42 | do c_in = 1, num_channels |
| 665 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 16 times.
|
53 | do k = 1, output_d |
| 666 |
2/2✓ Branch 0 taken 48 times.
✓ Branch 1 taken 24 times.
|
88 | do j = 1, output_w |
| 667 |
2/2✓ Branch 0 taken 120 times.
✓ Branch 1 taken 48 times.
|
192 | do i = 1, output_h |
| 668 |
2/2✓ Branch 0 taken 876 times.
✓ Branch 1 taken 120 times.
|
1044 | do c_out = 1, num_filters |
| 669 | out_idx = i + ( j - 1 ) * output_h + & | ||
| 670 | ( k - 1 ) * output_h * output_w + & | ||
| 671 | 876 | ( c_out - 1 ) * channel_size_out | |
| 672 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 876 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 876 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 876 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 876 times.
|
876 | grad_val = upstream_grad(out_idx, s) |
| 673 | |||
| 674 |
2/2✓ Branch 0 taken 3469 times.
✓ Branch 1 taken 876 times.
|
4465 | do kk = 1, kernel_d |
| 675 | k_in = ( k - 1 ) * stride(3) + & | ||
| 676 | 3469 | ( kk - 1 ) * dilation(3) + 1 | |
| 677 |
1/2✓ Branch 0 taken 3469 times.
✗ Branch 1 not taken.
|
4345 | if( k_in .ge. 1 .and. k_in .le. input_d )then |
| 678 |
2/2✓ Branch 0 taken 13839 times.
✓ Branch 1 taken 3469 times.
|
17308 | do kj = 1, kernel_w |
| 679 | j_in = ( j - 1 ) * stride(2) + & | ||
| 680 | 13839 | ( kj - 1 ) * dilation(2) + 1 | |
| 681 |
1/2✓ Branch 0 taken 13839 times.
✗ Branch 1 not taken.
|
17308 | if( j_in .ge. 1 .and. j_in .le. input_w )then |
| 682 |
2/2✓ Branch 0 taken 55315 times.
✓ Branch 1 taken 13839 times.
|
69154 | do ki = 1, kernel_h |
| 683 | i_in = ( i - 1 ) * stride(1) + & | ||
| 684 | 55315 | ( ki - 1 ) * dilation(1) + 1 | |
| 685 |
1/2✓ Branch 0 taken 55315 times.
✗ Branch 1 not taken.
|
55315 | if( i_in .ge. 1 .and. & |
| 686 | 13839 | i_in .le. input_h )then | |
| 687 | in_idx = i_in + & | ||
| 688 | ( j_in - 1 ) * input_h + & | ||
| 689 | ( k_in - 1 ) * input_h * input_w + & | ||
| 690 | 55315 | ( c_in - 1 ) * channel_size_in | |
| 691 | k_idx = ki + ( kj - 1 ) * kernel_h + & | ||
| 692 | ( kk - 1 ) * kernel_h * kernel_w + & | ||
| 693 | ( c_in - 1 ) * kernel_h * & | ||
| 694 | kernel_w * kernel_d + & | ||
| 695 | ( c_out - 1 ) * kernel_h * & | ||
| 696 | 55315 | kernel_w * kernel_d * num_channels | |
| 697 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 55315 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 55315 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 55315 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 55315 times.
|
55315 | kernel_val = this%right_operand%val(k_idx, 1) |
| 698 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 55315 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 55315 times.
|
110630 | output(in_idx, s) = & |
| 699 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 55315 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 55315 times.
|
110630 | output(in_idx, s) + & |
| 700 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 55315 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 55315 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 55315 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 55315 times.
|
276575 | grad_val * kernel_val |
| 701 | end if | ||
| 702 | end do | ||
| 703 | end if | ||
| 704 | end do | ||
| 705 | end if | ||
| 706 | end do | ||
| 707 | end do | ||
| 708 | end do | ||
| 709 | end do | ||
| 710 | end do | ||
| 711 | end do | ||
| 712 | end do | ||
| 713 | |||
| 714 | 13 | end subroutine get_partial_conv3d_input_val | |
| 715 | !------------------------------------------------------------------------------- | ||
| 716 |
2/4✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
|
14 | pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output) |
| 717 | !! Get partial derivative wrt kernel for 3D convolution (subroutine version) | ||
| 718 | implicit none | ||
| 719 | |||
| 720 | class(array_type), intent(in) :: this | ||
| 721 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 722 | real(real32), dimension(:,:), intent(out) :: output | ||
| 723 | |||
| 724 | ! Local variables | ||
| 725 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 726 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 727 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 728 | integer :: output_h, output_w, output_d | ||
| 729 | integer :: num_channels, num_filters | ||
| 730 | integer :: channel_size_in, channel_size_out | ||
| 731 | integer, dimension(3) :: stride, dilation | ||
| 732 | real(real32) :: grad_sum | ||
| 733 | |||
| 734 | |||
| 735 | ! Unpack parameters | ||
| 736 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | num_channels = this%indices(1) |
| 737 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | num_filters = this%indices(2) |
| 738 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 42 times.
✓ Branch 19 taken 14 times.
|
56 | stride = this%adj_ja(1:3,1) |
| 739 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 42 times.
✓ Branch 19 taken 14 times.
|
56 | dilation = this%adj_ja(1:3,2) |
| 740 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
|
14 | kernel_h = this%adj_ja(1,3) |
| 741 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
|
14 | kernel_w = this%adj_ja(2,3) |
| 742 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
|
14 | kernel_d = this%adj_ja(3,3) |
| 743 | |||
| 744 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | input_h = this%left_operand%shape(1) |
| 745 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | input_w = this%left_operand%shape(2) |
| 746 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | input_d = this%left_operand%shape(3) |
| 747 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | output_h = this%shape(1) |
| 748 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | output_w = this%shape(2) |
| 749 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
|
14 | output_d = this%shape(3) |
| 750 | |||
| 751 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 14 times.
✓ Branch 20 taken 2391 times.
✓ Branch 21 taken 14 times.
|
2419 | output = 0._real32 |
| 752 | |||
| 753 | 14 | channel_size_in = input_h * input_w * input_d | |
| 754 | 14 | channel_size_out = output_h * output_w * output_d | |
| 755 | |||
| 756 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 14 times.
|
38 | do c_out = 1, num_filters |
| 757 |
2/2✓ Branch 0 taken 56 times.
✓ Branch 1 taken 24 times.
|
94 | do c_in = 1, num_channels |
| 758 |
2/2✓ Branch 0 taken 177 times.
✓ Branch 1 taken 56 times.
|
257 | do kk = 1, kernel_d |
| 759 |
2/2✓ Branch 0 taken 635 times.
✓ Branch 1 taken 177 times.
|
868 | do kj = 1, kernel_w |
| 760 |
2/2✓ Branch 0 taken 2391 times.
✓ Branch 1 taken 635 times.
|
3203 | do ki = 1, kernel_h |
| 761 | k_idx = ki + (kj-1)*kernel_h + & | ||
| 762 | (kk-1)*kernel_h*kernel_w + & | ||
| 763 | (c_in-1)*kernel_h*kernel_w*kernel_d + & | ||
| 764 | 2391 | (c_out-1)*kernel_h*kernel_w*kernel_d*num_channels | |
| 765 | |||
| 766 | 2391 | grad_sum = 0._real32 | |
| 767 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 2391 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2391 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2391 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2391 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2391 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2391 times.
✓ Branch 18 taken 2391 times.
✓ Branch 19 taken 2391 times.
|
4782 | do s = 1, size(upstream_grad, dim=2) |
| 768 |
2/2✓ Branch 0 taken 8107 times.
✓ Branch 1 taken 2391 times.
|
12889 | do k = 1, output_d |
| 769 | 8107 | k_in = (k-1)*stride(3) + (kk-1)*dilation(3) + 1 | |
| 770 |
1/2✓ Branch 0 taken 8107 times.
✗ Branch 1 not taken.
|
10498 | if(k_in .ge. 1 .and. k_in .le. input_d)then |
| 771 |
2/2✓ Branch 0 taken 30115 times.
✓ Branch 1 taken 8107 times.
|
38222 | do j = 1, output_w |
| 772 | j_in = (j-1)*stride(2) + & | ||
| 773 | 30115 | (kj-1)*dilation(2) + 1 | |
| 774 |
1/2✓ Branch 0 taken 30115 times.
✗ Branch 1 not taken.
|
38222 | if(j_in .ge. 1 .and. j_in .le. input_w)then |
| 775 |
2/2✓ Branch 0 taken 125299 times.
✓ Branch 1 taken 30115 times.
|
155414 | do i = 1, output_h |
| 776 | i_in = (i-1)*stride(1) + & | ||
| 777 | 125299 | (ki-1)*dilation(1) + 1 | |
| 778 |
1/2✓ Branch 0 taken 125299 times.
✗ Branch 1 not taken.
|
155414 | if(i_in .ge. 1 .and. i_in .le. input_h)then |
| 779 | in_idx = i_in + (j_in-1)*input_h + & | ||
| 780 | (k_in-1)*input_h*input_w + & | ||
| 781 | 125299 | (c_in-1)*channel_size_in | |
| 782 | out_idx = i + (j-1)*output_h + & | ||
| 783 | (k-1)*output_h*output_w + & | ||
| 784 | 125299 | (c_out-1)*channel_size_out | |
| 785 | grad_sum = grad_sum + & | ||
| 786 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 125299 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125299 times.
|
250598 | upstream_grad(out_idx, s) * & |
| 787 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 125299 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 125299 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 125299 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125299 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 125299 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 125299 times.
|
375897 | this%left_operand%val(in_idx, s) |
| 788 | end if | ||
| 789 | end do | ||
| 790 | end if | ||
| 791 | end do | ||
| 792 | end if | ||
| 793 | end do | ||
| 794 | end do | ||
| 795 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2391 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2391 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2391 times.
|
3026 | output(k_idx, 1) = grad_sum |
| 796 | end do | ||
| 797 | end do | ||
| 798 | end do | ||
| 799 | end do | ||
| 800 | end do | ||
| 801 | |||
| 802 | 14 | end subroutine get_partial_conv3d_kernel_val | |
| 803 | !############################################################################### | ||
| 804 | |||
| 805 | end submodule athena__diffstruc_extd_submodule_conv | ||
| 806 |