GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_diffstruc_extd_sub_conv.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 334 366 91.3%
Functions: 0 0 -%
Branches: 723 1348 53.6%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_conv
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
7 6 module function conv1d(input, kernel, stride, dilation) result(output)
8 !! 1D convolution operation
9 implicit none
10
11 ! Arguments
12 type(array_type), intent(in), target :: input
13 type(array_type), intent(in), target :: kernel
14 integer, intent(in) :: stride
15 integer, intent(in) :: dilation
16 type(array_type), pointer :: output
17
18 ! Local variables
19 integer :: i, k, c_in, c_out, s
20 integer :: i_in, k_idx
21 integer :: input_h, kernel_h, output_h, num_channels, num_filters
22 real(real32) :: conv_sum
23 integer, dimension(3) :: output_shape
24
25 ! Extract dimensions
26 ! input: [H_in, C_in, B]
27 ! kernel: [K, C_in, C_out]
28
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 input_h = input%shape(1)
29
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 num_channels = input%shape(2)
30
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 kernel_h = kernel%shape(1)
31
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 num_filters = kernel%shape(3)
32
33 ! Calculate output dimensions
34 output_h = (input_h - dilation*(kernel_h - 1) - 1) / &
35 6 stride + 1
36
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 6 times.
24 output_shape = [output_h, num_filters, size(input%val, dim=2)]
37
38 6 output => input%create_result(array_shape = output_shape)
39
4/4
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 100 times.
✓ Branch 3 taken 6 times.
112 output%val = 0._real32
40
41 ! Perform convolution
42 do concurrent(s = 1:output_shape(3), c_out = 1:num_filters, &
43 6 i = 1:output_h)
44 100 conv_sum = 0._real32
45
2/2
✓ Branch 0 taken 340 times.
✓ Branch 1 taken 100 times.
440 do c_in = 1, num_channels
46
2/2
✓ Branch 0 taken 1209 times.
✓ Branch 1 taken 340 times.
1649 do k = 1, kernel_h
47 1209 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
48
1/2
✓ Branch 0 taken 1209 times.
✗ Branch 1 not taken.
1549 if(i_in .ge. 1 .and. i_in .le. input_h)then
49 k_idx = k + ( c_in - 1 ) * kernel_h + &
50 1209 ( c_out - 1 ) * kernel_h * num_channels
51 conv_sum = conv_sum + &
52 input%val(i_in + ( c_in - 1 ) * input_h, s) * &
53
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1209 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1209 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1209 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1209 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1209 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1209 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1209 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1209 times.
1209 kernel%val(k_idx, 1)
54 end if
55 end do
56 end do
57
10/14
✓ Branch 0 taken 22 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 100 times.
✓ Branch 3 taken 22 times.
✓ Branch 4 taken 100 times.
✓ Branch 5 taken 100 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 100 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 100 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 100 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 100 times.
228 output%val(i + (c_out-1)*output_h, s) = conv_sum
58 end do
59
60 ! Store parameters for backward pass
61
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
6 allocate(output%indices(2))
62
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 output%indices(1) = num_channels
63
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 output%indices(2) = num_filters
64
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
6 allocate(output%adj_ja(1,3))
65
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
6 output%adj_ja(1,1) = stride
66
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
6 output%adj_ja(1,2) = dilation
67
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
6 output%adj_ja(1,3) = kernel_h
68
69 6 output%get_partial_left => get_partial_conv1d_input
70 6 output%get_partial_right => get_partial_conv1d_kernel
71 6 output%get_partial_left_val => get_partial_conv1d_input_val
72 6 output%get_partial_right_val => get_partial_conv1d_kernel_val
73
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 if(input%requires_grad .or. kernel%requires_grad)then
74 6 output%requires_grad = .true.
75 6 output%is_forward = input%is_forward
76 6 output%operation = 'conv1d'
77 6 output%left_operand => input
78 6 output%right_operand => kernel
79 end if
80
81 6 end function conv1d
82 !-------------------------------------------------------------------------------
83 function get_partial_conv1d_input(this, upstream_grad) result(output)
84 !! Get partial derivative wrt input for 1D convolution
85 implicit none
86
87 class(array_type), intent(inout) :: this
88 type(array_type), intent(in) :: upstream_grad
89 type(array_type) :: output
90
91 call output%allocate(array_shape = [ this%left_operand%shape, &
92 size(this%left_operand%val, dim=2) ])
93 call this%get_partial_left_val(upstream_grad%val, output%val)
94
95 end function get_partial_conv1d_input
96 !-------------------------------------------------------------------------------
97 function get_partial_conv1d_kernel(this, upstream_grad) result(output)
98 !! Get partial derivative wrt kernel for 1D convolution
99 implicit none
100
101 class(array_type), intent(inout) :: this
102 type(array_type), intent(in) :: upstream_grad
103 type(array_type) :: output
104
105 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
106 call this%get_partial_right_val(upstream_grad%val, output%val)
107
108 end function get_partial_conv1d_kernel
109 !-------------------------------------------------------------------------------
110
2/4
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
2 pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output)
111 !! Get partial derivative wrt input for 1D convolution (subroutine version)
112 implicit none
113
114 class(array_type), intent(in) :: this
115 real(real32), dimension(:,:), intent(in) :: upstream_grad
116 real(real32), dimension(:,:), intent(out) :: output
117
118 ! Local variables
119 integer :: i, k, c_in, c_out, s
120 integer :: i_in, k_idx, out_idx
121 integer :: input_h, kernel_h, output_h, num_channels, num_filters
122 integer :: stride, dilation
123 real(real32) :: grad_val, kernel_val
124
125
126 ! Unpack parameters
127
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 num_channels = this%indices(1)
128
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 num_filters = this%indices(2)
129
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 stride = this%adj_ja(1,1)
130
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 dilation = this%adj_ja(1,2)
131
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 kernel_h = this%adj_ja(1,3)
132
133
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 input_h = this%left_operand%shape(1)
134
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 output_h = this%shape(1)
135
136
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 2 times.
✓ Branch 20 taken 28 times.
✓ Branch 21 taken 2 times.
32 output = 0._real32
137
138 ! Parallelised over batch, channels, and output positions
139 do concurrent(s = 1:size(upstream_grad, dim=2), c_in = 1:num_channels, &
140
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
2 i = 1:output_h, c_out = 1:num_filters)
141 99 out_idx = i + (c_out-1)*output_h
142
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 99 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 99 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 99 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 99 times.
99 grad_val = upstream_grad(out_idx, s)
143
144
9/10
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 27 times.
✓ Branch 3 taken 9 times.
✓ Branch 4 taken 99 times.
✓ Branch 5 taken 27 times.
✓ Branch 6 taken 99 times.
✓ Branch 7 taken 99 times.
✓ Branch 8 taken 99 times.
✗ Branch 9 not taken.
335 if(abs(grad_val) .gt. 1.e-30_real32)then
145
2/2
✓ Branch 0 taken 390 times.
✓ Branch 1 taken 99 times.
489 do k = 1, kernel_h
146 390 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
147
1/2
✓ Branch 0 taken 390 times.
✗ Branch 1 not taken.
489 if(i_in .ge. 1 .and. i_in .le. input_h)then
148 k_idx = k + ( c_in - 1 ) * kernel_h + &
149 390 ( c_out - 1 ) * kernel_h * num_channels
150
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 390 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 390 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 390 times.
390 kernel_val = this%right_operand%val(k_idx, 1)
151
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
780 output(i_in + ( c_in - 1 ) * input_h, s) = &
152
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 390 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 390 times.
780 output(i_in + ( c_in - 1 ) * input_h, s) + &
153
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 390 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 390 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 390 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 390 times.
1950 grad_val * kernel_val
154 end if
155 end do
156 end if
157 end do
158
159 2 end subroutine get_partial_conv1d_input_val
160 !-------------------------------------------------------------------------------
161
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 pure subroutine get_partial_conv1d_kernel_val(this, upstream_grad, output)
162 !! Get partial derivative wrt kernel for 1D convolution (subroutine version)
163 implicit none
164
165 class(array_type), intent(in) :: this
166 real(real32), dimension(:,:), intent(in) :: upstream_grad
167 real(real32), dimension(:,:), intent(out) :: output
168
169 ! Local variables
170 integer :: i, k, c_in, c_out, s
171 integer :: i_in, k_idx, out_idx
172 integer :: input_h, kernel_h, output_h, num_channels, num_filters
173 integer :: stride, dilation
174 real(real32) :: grad_sum
175
176
177 ! Unpack parameters
178
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 num_channels = this%indices(1)
179
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 num_filters = this%indices(2)
180
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
3 stride = this%adj_ja(1,1)
181
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
3 dilation = this%adj_ja(1,2)
182
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
3 kernel_h = this%adj_ja(1,3)
183
184
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 input_h = this%left_operand%shape(1)
185
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 output_h = this%shape(1)
186
187
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 18 taken 3 times.
✓ Branch 19 taken 3 times.
✓ Branch 20 taken 166 times.
✓ Branch 21 taken 3 times.
172 output = 0._real32
188
189 ! Parallelised over filters, channels, and kernel positions
190 3 do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, k = 1:kernel_h)
191 k_idx = k + ( c_in - 1 ) * kernel_h + &
192 166 ( c_out - 1 ) * kernel_h * num_channels
193
194 166 grad_sum = 0._real32
195
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 166 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 166 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 166 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 166 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 166 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 166 times.
✓ Branch 18 taken 166 times.
✓ Branch 19 taken 166 times.
332 do s = 1, size(upstream_grad, dim=2)
196
2/2
✓ Branch 0 taken 606 times.
✓ Branch 1 taken 166 times.
938 do i = 1, output_h
197 606 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
198
1/2
✓ Branch 0 taken 606 times.
✗ Branch 1 not taken.
772 if(i_in .ge. 1 .and. i_in .le. input_h)then
199 606 out_idx = i + ( c_out - 1 ) * output_h
200
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 606 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 606 times.
1212 grad_sum = grad_sum + upstream_grad(out_idx, s) * &
201
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 606 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 606 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 606 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 606 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 606 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 606 times.
1818 this%left_operand%val(i_in + ( c_in - 1 ) * input_h, s)
202 end if
203 end do
204 end do
205
9/12
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 27 times.
✓ Branch 3 taken 9 times.
✓ Branch 4 taken 166 times.
✓ Branch 5 taken 27 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 166 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 166 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 166 times.
205 output(k_idx, 1) = grad_sum
206 end do
207
208 3 end subroutine get_partial_conv1d_kernel_val
209 !###############################################################################
210
211
212 !###############################################################################
213 18 module function conv2d(input, kernel, stride, dilation) result(output)
214 !! 2D convolution operation
215 implicit none
216
217 ! Arguments
218 type(array_type), intent(in), target :: input
219 type(array_type), intent(in), target :: kernel
220 integer, dimension(2), intent(in) :: stride
221 integer, dimension(2), intent(in) :: dilation
222 type(array_type), pointer :: output
223
224 ! Local variables
225 integer :: i, j, ki, kj, c_in, c_out, s
226 integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx
227 integer :: input_h, input_w, kernel_h, kernel_w
228 integer :: output_h, output_w, num_channels, num_filters
229 integer :: channel_size_in, channel_size_out, kernel_channel_size
230 integer :: dil_kernel_h_m1, dil_kernel_w_m1
231 real(real32) :: conv_sum
232 integer, dimension(4) :: output_shape
233
234 ! Extract dimensions
235 ! input: [H_in, W_in, C_in, B]
236 ! kernel: [K_h, K_w, C_in, C_out]
237
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 input_h = input%shape(1)
238
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 input_w = input%shape(2)
239
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 num_channels = input%shape(3)
240
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 kernel_h = kernel%shape(1)
241
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 kernel_w = kernel%shape(2)
242
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 num_filters = kernel%shape(4)
243
244 ! Pre-compute common values
245 18 channel_size_in = input_h * input_w
246 18 kernel_channel_size = kernel_h * kernel_w
247 18 dil_kernel_h_m1 = dilation(1) * (kernel_h - 1)
248 18 dil_kernel_w_m1 = dilation(2) * (kernel_w - 1)
249
250 ! Calculate output dimensions
251 18 output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1
252 18 output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1
253 output_shape = [output_h, output_w, num_filters, &
254
2/2
✓ Branch 0 taken 72 times.
✓ Branch 1 taken 18 times.
90 size(input%val, dim=2)]
255
256 18 output => input%create_result(array_shape = output_shape)
257
4/4
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
✓ Branch 2 taken 732 times.
✓ Branch 3 taken 18 times.
768 output%val = 0._real32
258
259 18 channel_size_out = output_h * output_w
260
261 ! Perform convolution
262 do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, &
263 18 j = 1:output_w, i = 1:output_h)
264 732 conv_sum = 0._real32
265
2/2
✓ Branch 0 taken 1740 times.
✓ Branch 1 taken 732 times.
2472 do c_in = 1, num_channels
266 1740 in_base_idx = (c_in - 1) * channel_size_in
267 k_base_idx = (c_in - 1) * kernel_channel_size + &
268 1740 (c_out - 1) * kernel_channel_size * num_channels
269
2/2
✓ Branch 0 taken 5775 times.
✓ Branch 1 taken 1740 times.
8247 do kj = 1, kernel_w
270 5775 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
271
1/2
✓ Branch 0 taken 5775 times.
✗ Branch 1 not taken.
7515 if(j_in .ge. 1 .and. j_in .le. input_w)then
272
2/2
✓ Branch 0 taken 19607 times.
✓ Branch 1 taken 5775 times.
25382 do ki = 1, kernel_h
273 19607 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
274
1/2
✓ Branch 0 taken 19607 times.
✗ Branch 1 not taken.
25382 if(i_in .ge. 1 .and. i_in .le. input_h)then
275 19607 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
276 19607 k_idx = ki + (kj - 1) * kernel_h + k_base_idx
277 conv_sum = conv_sum + input%val(in_idx, s) * &
278
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 19607 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 19607 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 19607 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 19607 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 19607 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 19607 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 19607 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 19607 times.
19607 kernel%val(k_idx, 1)
279 end if
280 end do
281 end if
282 end do
283 end do
284 732 out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out
285
12/16
✓ Branch 0 taken 42 times.
✓ Branch 1 taken 18 times.
✓ Branch 2 taken 174 times.
✓ Branch 3 taken 42 times.
✓ Branch 4 taken 732 times.
✓ Branch 5 taken 174 times.
✓ Branch 6 taken 732 times.
✓ Branch 7 taken 732 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 732 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 732 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 732 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 732 times.
1698 output%val(out_idx, s) = conv_sum
286 end do
287
288 ! Store parameters for backward pass
289
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
18 allocate(output%indices(2))
290
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 output%indices(1) = num_channels
291
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 output%indices(2) = num_filters
292
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
18 allocate(output%adj_ja(2,3))
293
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✓ Branch 18 taken 36 times.
✓ Branch 19 taken 18 times.
54 output%adj_ja(1:2,1) = stride
294
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✓ Branch 18 taken 36 times.
✓ Branch 19 taken 18 times.
54 output%adj_ja(1:2,2) = dilation
295
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
18 output%adj_ja(1,3) = kernel_h
296
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
18 output%adj_ja(2,3) = kernel_w
297
298
299 18 output%get_partial_left => get_partial_conv2d_input
300 18 output%get_partial_right => get_partial_conv2d_kernel
301 18 output%get_partial_left_val => get_partial_conv2d_input_val
302 18 output%get_partial_right_val => get_partial_conv2d_kernel_val
303
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 if(input%requires_grad .or. kernel%requires_grad)then
304 18 output%requires_grad = .true.
305 18 output%is_forward = input%is_forward
306 18 output%operation = 'conv2d'
307 18 output%left_operand => input
308 18 output%right_operand => kernel
309 end if
310
311 18 end function conv2d
312 !-------------------------------------------------------------------------------
313 function get_partial_conv2d_input(this, upstream_grad) result(output)
314 !! Get partial derivative wrt input for 2D convolution
315 implicit none
316
317 class(array_type), intent(inout) :: this
318 type(array_type), intent(in) :: upstream_grad
319 type(array_type) :: output
320
321 call output%allocate(array_shape = [ this%left_operand%shape, &
322 size(this%left_operand%val, dim=2) ])
323 call this%get_partial_left_val(upstream_grad%val, output%val)
324
325 end function get_partial_conv2d_input
326 !-------------------------------------------------------------------------------
327 function get_partial_conv2d_kernel(this, upstream_grad) result(output)
328 !! Get partial derivative wrt kernel for 2D convolution
329 implicit none
330
331 class(array_type), intent(inout) :: this
332 type(array_type), intent(in) :: upstream_grad
333 type(array_type) :: output
334
335 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
336 call this%get_partial_right_val(upstream_grad%val, output%val)
337
338 end function get_partial_conv2d_kernel
339 !-------------------------------------------------------------------------------
340
2/4
✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13 times.
✗ Branch 3 not taken.
13 pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output)
341 !! Get partial derivative wrt input for 2D convolution (subroutine version)
342 implicit none
343
344 class(array_type), intent(in) :: this
345 real(real32), dimension(:,:), intent(in) :: upstream_grad
346 real(real32), dimension(:,:), intent(out) :: output
347
348 ! Local variables
349 integer :: i, j, ki, kj, c_in, c_out, s
350 integer :: i_in, j_in, k_idx, out_idx, in_idx
351 integer :: in_base_idx, k_base_idx, kernel_channel_size
352 integer :: input_h, input_w, kernel_h, kernel_w
353 integer :: output_h, output_w, num_channels, num_filters
354 integer, dimension(2) :: stride, dilation
355 integer :: channel_size_in, channel_size_out
356 real(real32) :: grad_val, kernel_val
357
358
359 ! Unpack parameters
360
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 num_channels = this%indices(1)
361
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 num_filters = this%indices(2)
362
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 26 times.
✓ Branch 19 taken 13 times.
39 stride = this%adj_ja(1:2,1)
363
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 26 times.
✓ Branch 19 taken 13 times.
39 dilation = this%adj_ja(1:2,2)
364
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 kernel_h = this%adj_ja(1,3)
365
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 kernel_w = this%adj_ja(2,3)
366
367
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 input_h = this%left_operand%shape(1)
368
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 input_w = this%left_operand%shape(2)
369
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 output_h = this%shape(1)
370
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 output_w = this%shape(2)
371 13 channel_size_in = input_h * input_w
372 13 channel_size_out = output_h * output_w
373 13 kernel_channel_size = kernel_h * kernel_w
374
375
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
✓ Branch 20 taken 159 times.
✓ Branch 21 taken 13 times.
185 output = 0._real32
376
377 ! Parallelised over batch, output channels, output spatial dims
378 do concurrent(s = 1:size(upstream_grad, dim=2), c_out = 1:num_filters, &
379
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 13 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 13 times.
13 j = 1:output_w, i = 1:output_h)
380 84 out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out
381
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 84 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 84 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 84 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 84 times.
84 grad_val = upstream_grad(out_idx, s)
382
383
10/10
✓ Branch 0 taken 15 times.
✓ Branch 1 taken 13 times.
✓ Branch 2 taken 21 times.
✓ Branch 3 taken 15 times.
✓ Branch 4 taken 84 times.
✓ Branch 5 taken 21 times.
✓ Branch 6 taken 84 times.
✓ Branch 7 taken 84 times.
✓ Branch 8 taken 82 times.
✓ Branch 9 taken 2 times.
299 if(abs(grad_val) .gt. 1.e-30_real32)then
384
2/2
✓ Branch 0 taken 298 times.
✓ Branch 1 taken 82 times.
380 do c_in = 1, num_channels
385 298 in_base_idx = (c_in - 1) * channel_size_in
386 k_base_idx = (c_in - 1) * kernel_channel_size + &
387 298 (c_out - 1) * kernel_channel_size * num_channels
388
389
2/2
✓ Branch 0 taken 1163 times.
✓ Branch 1 taken 298 times.
1543 do kj = 1, kernel_w
390 1163 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
391
1/2
✓ Branch 0 taken 1163 times.
✗ Branch 1 not taken.
1461 if(j_in .ge. 1 .and. j_in .le. input_w)then
392
2/2
✓ Branch 0 taken 4621 times.
✓ Branch 1 taken 1163 times.
5784 do ki = 1, kernel_h
393 4621 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
394
1/2
✓ Branch 0 taken 4621 times.
✗ Branch 1 not taken.
5784 if(i_in .ge. 1 .and. i_in .le. input_h)then
395 4621 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
396 k_idx = (kernel_h - ki + 1) + &
397 4621 (kernel_w - kj) * kernel_h + k_base_idx
398
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 4621 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4621 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4621 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4621 times.
4621 kernel_val = this%right_operand%val(k_idx, 1)
399
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 4621 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4621 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4621 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4621 times.
18484 output(in_idx, s) = output(in_idx, s) + &
400
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 4621 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4621 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4621 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4621 times.
23105 grad_val * kernel_val
401 end if
402 end do
403 end if
404 end do
405 end do
406 end if
407 end do
408
409 13 end subroutine get_partial_conv2d_input_val
410 !-------------------------------------------------------------------------------
411
2/4
✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
14 pure subroutine get_partial_conv2d_kernel_val(this, upstream_grad, output)
412 !! Get partial derivative wrt kernel for 2D convolution (subroutine version)
413 implicit none
414
415 class(array_type), intent(in) :: this
416 real(real32), dimension(:,:), intent(in) :: upstream_grad
417 real(real32), dimension(:,:), intent(out) :: output
418
419 ! Local variables
420 integer :: i, j, ki, kj, c_in, c_out, s
421 integer :: i_in, j_in, k_idx, out_idx, in_idx
422 integer :: in_base_idx, out_base_idx, k_base_idx, kernel_channel_size
423 integer :: input_h, input_w, kernel_h, kernel_w
424 integer :: output_h, output_w, num_channels, num_filters
425 integer, dimension(2) :: stride, dilation
426 integer :: channel_size_in, channel_size_out
427 real(real32) :: grad_sum
428
429
430 ! Unpack parameters
431
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 num_channels = this%indices(1)
432
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 num_filters = this%indices(2)
433
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 28 times.
✓ Branch 19 taken 14 times.
42 stride = this%adj_ja(1:2,1)
434
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 28 times.
✓ Branch 19 taken 14 times.
42 dilation = this%adj_ja(1:2,2)
435
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
14 kernel_h = this%adj_ja(1,3)
436
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
14 kernel_w = this%adj_ja(2,3)
437
438
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 input_h = this%left_operand%shape(1)
439
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 input_w = this%left_operand%shape(2)
440
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 output_h = this%shape(1)
441
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 output_w = this%shape(2)
442 14 channel_size_in = input_h * input_w
443 14 channel_size_out = output_h * output_w
444 14 kernel_channel_size = kernel_h * kernel_w
445
446
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 14 times.
✓ Branch 20 taken 635 times.
✓ Branch 21 taken 14 times.
663 output = 0._real32
447
448 ! Parallelised over filters, channels, and kernel dimensions
449 do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, &
450 14 kj = 1:kernel_w, ki = 1:kernel_h)
451 635 out_base_idx = (c_out - 1) * channel_size_out
452 635 in_base_idx = (c_in - 1) * channel_size_in
453 k_base_idx = (c_in - 1) * kernel_channel_size + &
454 635 (c_out - 1) * kernel_channel_size * num_channels
455 635 k_idx = ki + (kj - 1) * kernel_h + k_base_idx
456
457 635 grad_sum = 0._real32
458
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 635 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 635 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 635 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 635 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 635 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 635 times.
✓ Branch 18 taken 635 times.
✓ Branch 19 taken 635 times.
1270 do s = 1, size(upstream_grad, dim=2)
459
2/2
✓ Branch 0 taken 2199 times.
✓ Branch 1 taken 635 times.
3469 do j = 1, output_w
460 2199 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
461
1/2
✓ Branch 0 taken 2199 times.
✗ Branch 1 not taken.
2834 if(j_in .ge. 1 .and. j_in .le. input_w)then
462
2/2
✓ Branch 0 taken 8511 times.
✓ Branch 1 taken 2199 times.
10710 do i = 1, output_h
463 8511 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
464
1/2
✓ Branch 0 taken 8511 times.
✗ Branch 1 not taken.
10710 if(i_in .ge. 1 .and. i_in .le. input_h)then
465 8511 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
466 8511 out_idx = i + (j - 1) * output_h + out_base_idx
467 grad_sum = grad_sum + &
468
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 8511 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8511 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8511 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8511 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8511 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8511 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8511 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8511 times.
8511 upstream_grad(out_idx, s) * this%left_operand%val(in_idx, s)
469 end if
470 end do
471 end if
472 end do
473 end do
474
11/14
✓ Branch 0 taken 20 times.
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 40 times.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 106 times.
✓ Branch 5 taken 40 times.
✓ Branch 6 taken 635 times.
✓ Branch 7 taken 106 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 635 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 635 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 635 times.
815 output(k_idx, 1) = grad_sum
475 end do
476
477 14 end subroutine get_partial_conv2d_kernel_val
478 !###############################################################################
479
480
481 !###############################################################################
482 16 module function conv3d(input, kernel, stride, dilation) result(output)
483 !! 3D convolution operation
484 implicit none
485
486 ! Arguments
487 type(array_type), intent(in), target :: input
488 type(array_type), intent(in), target :: kernel
489 integer, dimension(3), intent(in) :: stride
490 integer, dimension(3), intent(in) :: dilation
491 type(array_type), pointer :: output
492
493 ! Local variables
494 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
495 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
496 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
497 integer :: output_h, output_w, output_d
498 integer :: num_channels, num_filters
499 integer :: channel_size_in, channel_size_out
500 real(real32) :: conv_sum
501 integer, dimension(5) :: output_shape
502
503 ! Extract dimensions
504 ! input: [H_in, W_in, D_in, C_in, B]
505 ! kernel: [K_h, K_w, K_d, C_in, C_out]
506
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 input_h = input%shape(1)
507
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 input_w = input%shape(2)
508
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 input_d = input%shape(3)
509
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 num_channels = input%shape(4)
510
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 kernel_h = kernel%shape(1)
511
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 kernel_w = kernel%shape(2)
512
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 kernel_d = kernel%shape(3)
513
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 num_filters = kernel%shape(5)
514
515 ! Calculate output dimensions
516 output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / &
517 16 stride(1) + 1
518 output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / &
519 16 stride(2) + 1
520 output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / &
521 16 stride(3) + 1
522 output_shape = [output_h, output_w, output_d, num_filters, &
523
2/2
✓ Branch 0 taken 80 times.
✓ Branch 1 taken 16 times.
96 size(input%val, dim=2)]
524
525 16 output => input%create_result(array_shape = output_shape)
526
4/4
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 2172 times.
✓ Branch 3 taken 16 times.
2204 output%val = 0._real32
527
528 16 channel_size_in = input_h * input_w * input_d
529 16 channel_size_out = output_h * output_w * output_d
530
531 ! Perform convolution - optimised with do concurrent
532 do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, &
533 16 k = 1:output_d, j = 1:output_w, i = 1:output_h)
534
535 2172 conv_sum = 0._real32
536
2/2
✓ Branch 0 taken 6924 times.
✓ Branch 1 taken 2172 times.
9096 do c_in = 1, num_channels
537
2/2
✓ Branch 0 taken 22479 times.
✓ Branch 1 taken 6924 times.
31575 do kk = 1, kernel_d
538 22479 k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1
539
1/2
✓ Branch 0 taken 22479 times.
✗ Branch 1 not taken.
29403 if(k_in .ge. 1 .and. k_in .le. input_d)then
540
2/2
✓ Branch 0 taken 74327 times.
✓ Branch 1 taken 22479 times.
96806 do kj = 1, kernel_w
541 74327 j_in = ( j - 1 ) * stride(2) + (kj - 1) * dilation(2) + 1
542
1/2
✓ Branch 0 taken 74327 times.
✗ Branch 1 not taken.
96806 if(j_in .ge. 1 .and. j_in .le. input_w)then
543
2/2
✓ Branch 0 taken 250605 times.
✓ Branch 1 taken 74327 times.
324932 do ki = 1, kernel_h
544 i_in = ( i - 1 ) * stride(1) + &
545 250605 ( ki - 1 ) * dilation(1) + 1
546
1/2
✓ Branch 0 taken 250605 times.
✗ Branch 1 not taken.
324932 if(i_in .ge. 1 .and. i_in .le. input_h)then
547 in_idx = i_in + ( j_in - 1 ) * input_h + &
548 ( k_in - 1 ) * input_h * input_w + &
549 250605 ( c_in - 1 ) * channel_size_in
550 k_idx = ki + ( kj - 1 ) * kernel_h + &
551 ( kk - 1 ) * kernel_h * kernel_w + &
552 ( c_in - 1 ) * kernel_h * kernel_w * &
553 kernel_d + &
554 ( c_out - 1 ) * kernel_h * kernel_w * &
555 250605 kernel_d * num_channels
556 conv_sum = conv_sum + input%val(in_idx, s) * &
557
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 250605 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 250605 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 250605 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 250605 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 250605 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 250605 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 250605 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 250605 times.
250605 kernel%val(k_idx, 1)
558 end if
559 end do
560 end if
561 end do
562 end if
563 end do
564 end do
565 out_idx = i + ( j - 1 ) * output_h + &
566 ( k - 1 ) * output_h * output_w + &
567 2172 ( c_out - 1 ) * channel_size_out
568
14/18
✓ Branch 0 taken 30 times.
✓ Branch 1 taken 16 times.
✓ Branch 2 taken 102 times.
✓ Branch 3 taken 30 times.
✓ Branch 4 taken 498 times.
✓ Branch 5 taken 102 times.
✓ Branch 6 taken 2172 times.
✓ Branch 7 taken 498 times.
✓ Branch 8 taken 2172 times.
✓ Branch 9 taken 2172 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2172 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2172 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2172 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2172 times.
4990 output%val(out_idx, s) = conv_sum
569 end do
570
571 ! Store parameters for backward pass
572
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
16 allocate(output%indices(2))
573
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 output%indices(1) = num_channels
574
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 output%indices(2) = num_filters
575
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
16 allocate(output%adj_ja(3,3))
576
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✓ Branch 18 taken 48 times.
✓ Branch 19 taken 16 times.
64 output%adj_ja(1:3,1) = stride
577
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✓ Branch 18 taken 48 times.
✓ Branch 19 taken 16 times.
64 output%adj_ja(1:3,2) = dilation
578
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 output%adj_ja(1,3) = kernel_h
579
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 output%adj_ja(2,3) = kernel_w
580
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 output%adj_ja(3,3) = kernel_d
581
582 16 output%get_partial_left => get_partial_conv3d_input
583 16 output%get_partial_right => get_partial_conv3d_kernel
584 16 output%get_partial_left_val => get_partial_conv3d_input_val
585 16 output%get_partial_right_val => get_partial_conv3d_kernel_val
586
1/2
✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
16 if(input%requires_grad .or. kernel%requires_grad)then
587 16 output%requires_grad = .true.
588 16 output%is_forward = input%is_forward
589 16 output%operation = 'conv3d'
590 16 output%left_operand => input
591 16 output%right_operand => kernel
592 end if
593
594 16 end function conv3d
595 !-------------------------------------------------------------------------------
596 function get_partial_conv3d_input(this, upstream_grad) result(output)
597 !! Get partial derivative wrt input for 3D convolution
598 implicit none
599
600 class(array_type), intent(inout) :: this
601 type(array_type), intent(in) :: upstream_grad
602 type(array_type) :: output
603
604 call output%allocate(array_shape = [ this%left_operand%shape, &
605 size(this%left_operand%val, dim=2) ])
606 call this%get_partial_left_val(upstream_grad%val, output%val)
607
608 end function get_partial_conv3d_input
609 !-------------------------------------------------------------------------------
610 function get_partial_conv3d_kernel(this, upstream_grad) result(output)
611 !! Get partial derivative wrt kernel for 3D convolution
612 implicit none
613
614 class(array_type), intent(inout) :: this
615 type(array_type), intent(in) :: upstream_grad
616 type(array_type) :: output
617
618 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
619 call this%get_partial_right_val(upstream_grad%val, output%val)
620
621 end function get_partial_conv3d_kernel
622 !-------------------------------------------------------------------------------
623
2/4
✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 13 times.
✗ Branch 3 not taken.
13 pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output)
624 !! Get partial derivative wrt input for 3D convolution (subroutine version)
625 implicit none
626
627 class(array_type), intent(in) :: this
628 real(real32), dimension(:,:), intent(in) :: upstream_grad
629 real(real32), dimension(:,:), intent(out) :: output
630
631 ! Local variables
632 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
633 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
634 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
635 integer :: output_h, output_w, output_d
636 integer :: num_channels, num_filters
637 integer :: channel_size_in, channel_size_out
638 integer, dimension(3) :: stride, dilation
639 real(real32) :: grad_val, kernel_val
640
641
642 ! Unpack parameters
643
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 num_channels = this%indices(1)
644
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 num_filters = this%indices(2)
645
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 39 times.
✓ Branch 19 taken 13 times.
52 stride = this%adj_ja(1:3,1)
646
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 39 times.
✓ Branch 19 taken 13 times.
52 dilation = this%adj_ja(1:3,2)
647
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 kernel_h = this%adj_ja(1,3)
648
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 kernel_w = this%adj_ja(2,3)
649
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 kernel_d = this%adj_ja(3,3)
650
651
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 input_h = this%left_operand%shape(1)
652
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 input_w = this%left_operand%shape(2)
653
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 input_d = this%left_operand%shape(3)
654
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 output_h = this%shape(1)
655
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 output_w = this%shape(2)
656
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
13 output_d = this%shape(3)
657
658
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
✓ Branch 20 taken 883 times.
✓ Branch 21 taken 13 times.
909 output = 0._real32
659
660 13 channel_size_in = input_h * input_w * input_d
661 13 channel_size_out = output_h * output_w * output_d
662
663
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 13 times.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 13 times.
26 do s = 1, size(upstream_grad, dim=2)
664
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 13 times.
42 do c_in = 1, num_channels
665
2/2
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 16 times.
53 do k = 1, output_d
666
2/2
✓ Branch 0 taken 48 times.
✓ Branch 1 taken 24 times.
88 do j = 1, output_w
667
2/2
✓ Branch 0 taken 120 times.
✓ Branch 1 taken 48 times.
192 do i = 1, output_h
668
2/2
✓ Branch 0 taken 876 times.
✓ Branch 1 taken 120 times.
1044 do c_out = 1, num_filters
669 out_idx = i + ( j - 1 ) * output_h + &
670 ( k - 1 ) * output_h * output_w + &
671 876 ( c_out - 1 ) * channel_size_out
672
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 876 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 876 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 876 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 876 times.
876 grad_val = upstream_grad(out_idx, s)
673
674
2/2
✓ Branch 0 taken 3469 times.
✓ Branch 1 taken 876 times.
4465 do kk = 1, kernel_d
675 k_in = ( k - 1 ) * stride(3) + &
676 3469 ( kk - 1 ) * dilation(3) + 1
677
1/2
✓ Branch 0 taken 3469 times.
✗ Branch 1 not taken.
4345 if( k_in .ge. 1 .and. k_in .le. input_d )then
678
2/2
✓ Branch 0 taken 13839 times.
✓ Branch 1 taken 3469 times.
17308 do kj = 1, kernel_w
679 j_in = ( j - 1 ) * stride(2) + &
680 13839 ( kj - 1 ) * dilation(2) + 1
681
1/2
✓ Branch 0 taken 13839 times.
✗ Branch 1 not taken.
17308 if( j_in .ge. 1 .and. j_in .le. input_w )then
682
2/2
✓ Branch 0 taken 55315 times.
✓ Branch 1 taken 13839 times.
69154 do ki = 1, kernel_h
683 i_in = ( i - 1 ) * stride(1) + &
684 55315 ( ki - 1 ) * dilation(1) + 1
685
1/2
✓ Branch 0 taken 55315 times.
✗ Branch 1 not taken.
55315 if( i_in .ge. 1 .and. &
686 13839 i_in .le. input_h )then
687 in_idx = i_in + &
688 ( j_in - 1 ) * input_h + &
689 ( k_in - 1 ) * input_h * input_w + &
690 55315 ( c_in - 1 ) * channel_size_in
691 k_idx = ki + ( kj - 1 ) * kernel_h + &
692 ( kk - 1 ) * kernel_h * kernel_w + &
693 ( c_in - 1 ) * kernel_h * &
694 kernel_w * kernel_d + &
695 ( c_out - 1 ) * kernel_h * &
696 55315 kernel_w * kernel_d * num_channels
697
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 55315 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 55315 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 55315 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 55315 times.
55315 kernel_val = this%right_operand%val(k_idx, 1)
698
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 55315 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 55315 times.
110630 output(in_idx, s) = &
699
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 55315 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 55315 times.
110630 output(in_idx, s) + &
700
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 55315 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 55315 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 55315 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 55315 times.
276575 grad_val * kernel_val
701 end if
702 end do
703 end if
704 end do
705 end if
706 end do
707 end do
708 end do
709 end do
710 end do
711 end do
712 end do
713
714 13 end subroutine get_partial_conv3d_input_val
715 !-------------------------------------------------------------------------------
716
2/4
✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
14 pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output)
717 !! Get partial derivative wrt kernel for 3D convolution (subroutine version)
718 implicit none
719
720 class(array_type), intent(in) :: this
721 real(real32), dimension(:,:), intent(in) :: upstream_grad
722 real(real32), dimension(:,:), intent(out) :: output
723
724 ! Local variables
725 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
726 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
727 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
728 integer :: output_h, output_w, output_d
729 integer :: num_channels, num_filters
730 integer :: channel_size_in, channel_size_out
731 integer, dimension(3) :: stride, dilation
732 real(real32) :: grad_sum
733
734
735 ! Unpack parameters
736
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 num_channels = this%indices(1)
737
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 num_filters = this%indices(2)
738
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 42 times.
✓ Branch 19 taken 14 times.
56 stride = this%adj_ja(1:3,1)
739
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 42 times.
✓ Branch 19 taken 14 times.
56 dilation = this%adj_ja(1:3,2)
740
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
14 kernel_h = this%adj_ja(1,3)
741
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
14 kernel_w = this%adj_ja(2,3)
742
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
14 kernel_d = this%adj_ja(3,3)
743
744
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 input_h = this%left_operand%shape(1)
745
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 input_w = this%left_operand%shape(2)
746
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 input_d = this%left_operand%shape(3)
747
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 output_h = this%shape(1)
748
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 output_w = this%shape(2)
749
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 output_d = this%shape(3)
750
751
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 14 times.
✓ Branch 20 taken 2391 times.
✓ Branch 21 taken 14 times.
2419 output = 0._real32
752
753 14 channel_size_in = input_h * input_w * input_d
754 14 channel_size_out = output_h * output_w * output_d
755
756
2/2
✓ Branch 0 taken 24 times.
✓ Branch 1 taken 14 times.
38 do c_out = 1, num_filters
757
2/2
✓ Branch 0 taken 56 times.
✓ Branch 1 taken 24 times.
94 do c_in = 1, num_channels
758
2/2
✓ Branch 0 taken 177 times.
✓ Branch 1 taken 56 times.
257 do kk = 1, kernel_d
759
2/2
✓ Branch 0 taken 635 times.
✓ Branch 1 taken 177 times.
868 do kj = 1, kernel_w
760
2/2
✓ Branch 0 taken 2391 times.
✓ Branch 1 taken 635 times.
3203 do ki = 1, kernel_h
761 k_idx = ki + (kj-1)*kernel_h + &
762 (kk-1)*kernel_h*kernel_w + &
763 (c_in-1)*kernel_h*kernel_w*kernel_d + &
764 2391 (c_out-1)*kernel_h*kernel_w*kernel_d*num_channels
765
766 2391 grad_sum = 0._real32
767
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 2391 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2391 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2391 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2391 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2391 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2391 times.
✓ Branch 18 taken 2391 times.
✓ Branch 19 taken 2391 times.
4782 do s = 1, size(upstream_grad, dim=2)
768
2/2
✓ Branch 0 taken 8107 times.
✓ Branch 1 taken 2391 times.
12889 do k = 1, output_d
769 8107 k_in = (k-1)*stride(3) + (kk-1)*dilation(3) + 1
770
1/2
✓ Branch 0 taken 8107 times.
✗ Branch 1 not taken.
10498 if(k_in .ge. 1 .and. k_in .le. input_d)then
771
2/2
✓ Branch 0 taken 30115 times.
✓ Branch 1 taken 8107 times.
38222 do j = 1, output_w
772 j_in = (j-1)*stride(2) + &
773 30115 (kj-1)*dilation(2) + 1
774
1/2
✓ Branch 0 taken 30115 times.
✗ Branch 1 not taken.
38222 if(j_in .ge. 1 .and. j_in .le. input_w)then
775
2/2
✓ Branch 0 taken 125299 times.
✓ Branch 1 taken 30115 times.
155414 do i = 1, output_h
776 i_in = (i-1)*stride(1) + &
777 125299 (ki-1)*dilation(1) + 1
778
1/2
✓ Branch 0 taken 125299 times.
✗ Branch 1 not taken.
155414 if(i_in .ge. 1 .and. i_in .le. input_h)then
779 in_idx = i_in + (j_in-1)*input_h + &
780 (k_in-1)*input_h*input_w + &
781 125299 (c_in-1)*channel_size_in
782 out_idx = i + (j-1)*output_h + &
783 (k-1)*output_h*output_w + &
784 125299 (c_out-1)*channel_size_out
785 grad_sum = grad_sum + &
786
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 125299 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 125299 times.
250598 upstream_grad(out_idx, s) * &
787
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 125299 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 125299 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 125299 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 125299 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 125299 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 125299 times.
375897 this%left_operand%val(in_idx, s)
788 end if
789 end do
790 end if
791 end do
792 end if
793 end do
794 end do
795
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2391 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2391 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2391 times.
3026 output(k_idx, 1) = grad_sum
796 end do
797 end do
798 end do
799 end do
800 end do
801
802 14 end subroutine get_partial_conv3d_kernel_val
803 !###############################################################################
804
805 end submodule athena__diffstruc_extd_submodule_conv
806