GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_conv.f90
Date: 2025-12-10 07:37:07
            Exec  Total  Coverage
Lines:         0      0    100.0%
Functions:     0      0        -%
Branches:      0      0        -%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_conv
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
7 module function conv1d(input, kernel, stride, dilation) result(output)
8 !! 1D convolution operation
9 implicit none
10
11 ! Arguments
12 type(array_type), intent(in), target :: input
13 type(array_type), intent(in), target :: kernel
14 integer, intent(in) :: stride
15 integer, intent(in) :: dilation
16 type(array_type), pointer :: output
17
18 ! Local variables
19 integer :: i, k, c_in, c_out, s
20 integer :: i_in, k_idx
21 integer :: input_h, kernel_h, output_h, num_channels, num_filters
22 real(real32) :: conv_sum
23 integer, dimension(3) :: output_shape
24
25 ! Extract dimensions
26 ! input: [H_in, C_in, B]
27 ! kernel: [K, C_in, C_out]
28 input_h = input%shape(1)
29 num_channels = input%shape(2)
30 kernel_h = kernel%shape(1)
31 num_filters = kernel%shape(3)
32
33 ! Calculate output dimensions
34 output_h = (input_h - dilation*(kernel_h - 1) - 1) / &
35 stride + 1
36 output_shape = [output_h, num_filters, size(input%val, dim=2)]
37
38 output => input%create_result(array_shape = output_shape)
39 output%val = 0._real32
40
41 ! Perform convolution
42 do concurrent(s = 1:output_shape(3), c_out = 1:num_filters, &
43 i = 1:output_h)
44 conv_sum = 0._real32
45 do c_in = 1, num_channels
46 do k = 1, kernel_h
47 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
48 if(i_in .ge. 1 .and. i_in .le. input_h)then
49 k_idx = k + ( c_in - 1 ) * kernel_h + &
50 ( c_out - 1 ) * kernel_h * num_channels
51 conv_sum = conv_sum + &
52 input%val(i_in + ( c_in - 1 ) * input_h, s) * &
53 kernel%val(k_idx, 1)
54 end if
55 end do
56 end do
57 output%val(i + (c_out-1)*output_h, s) = conv_sum
58 end do
59
60 ! Store parameters for backward pass
61 allocate(output%indices(2))
62 output%indices(1) = num_channels
63 output%indices(2) = num_filters
64 allocate(output%adj_ja(1,3))
65 output%adj_ja(1,1) = stride
66 output%adj_ja(1,2) = dilation
67 output%adj_ja(1,3) = kernel_h
68
69 output%get_partial_left => get_partial_conv1d_input
70 output%get_partial_right => get_partial_conv1d_kernel
71 output%get_partial_left_val => get_partial_conv1d_input_val
72 output%get_partial_right_val => get_partial_conv1d_kernel_val
73 if(input%requires_grad .or. kernel%requires_grad)then
74 output%requires_grad = .true.
75 output%is_forward = input%is_forward
76 output%operation = 'conv1d'
77 output%left_operand => input
78 output%right_operand => kernel
79 end if
80
81 end function conv1d
82 !-------------------------------------------------------------------------------
83 function get_partial_conv1d_input(this, upstream_grad) result(output)
84 !! Get partial derivative wrt input for 1D convolution
85 implicit none
86
87 class(array_type), intent(inout) :: this
88 type(array_type), intent(in) :: upstream_grad
89 type(array_type) :: output
90
91 call output%allocate(array_shape = [ this%left_operand%shape, &
92 size(this%left_operand%val, dim=2) ])
93 call this%get_partial_left_val(upstream_grad%val, output%val)
94
95 end function get_partial_conv1d_input
96 !-------------------------------------------------------------------------------
97 function get_partial_conv1d_kernel(this, upstream_grad) result(output)
98 !! Get partial derivative wrt kernel for 1D convolution
99 implicit none
100
101 class(array_type), intent(inout) :: this
102 type(array_type), intent(in) :: upstream_grad
103 type(array_type) :: output
104
105 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
106 call this%get_partial_right_val(upstream_grad%val, output%val)
107
108 end function get_partial_conv1d_kernel
109 !-------------------------------------------------------------------------------
110 pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output)
111 !! Get partial derivative wrt input for 1D convolution (subroutine version)
112 implicit none
113
114 class(array_type), intent(in) :: this
115 real(real32), dimension(:,:), intent(in) :: upstream_grad
116 real(real32), dimension(:,:), intent(out) :: output
117
118 ! Local variables
119 integer :: i, k, c_in, c_out, s
120 integer :: i_in, k_idx, out_idx
121 integer :: input_h, kernel_h, output_h, num_channels, num_filters
122 integer :: stride, dilation
123 real(real32) :: grad_val, kernel_val
124
125
126 ! Unpack parameters
127 num_channels = this%indices(1)
128 num_filters = this%indices(2)
129 stride = this%adj_ja(1,1)
130 dilation = this%adj_ja(1,2)
131 kernel_h = this%adj_ja(1,3)
132
133 input_h = this%left_operand%shape(1)
134 output_h = this%shape(1)
135
136 output = 0._real32
137
138 ! Parallelised over batch, input channels, output positions, and filters
139 do concurrent(s = 1:size(upstream_grad, dim=2), c_in = 1:num_channels, &
140 i = 1:output_h, c_out = 1:num_filters)
141 out_idx = i + (c_out-1)*output_h
142 grad_val = upstream_grad(out_idx, s)
143
144 if(abs(grad_val) > 1.e-30_real32)then
145 do k = 1, kernel_h
146 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
147 if(i_in .ge. 1 .and. i_in .le. input_h)then
148 k_idx = k + ( c_in - 1 ) * kernel_h + &
149 ( c_out - 1 ) * kernel_h * num_channels
150 kernel_val = this%right_operand%val(k_idx, 1)
151 output(i_in + ( c_in - 1 ) * input_h, s) = &
152 output(i_in + ( c_in - 1 ) * input_h, s) + &
153 grad_val * kernel_val
154 end if
155 end do
156 end if
157 end do
158
159 end subroutine get_partial_conv1d_input_val
160 !-------------------------------------------------------------------------------
161 pure subroutine get_partial_conv1d_kernel_val(this, upstream_grad, output)
162 !! Get partial derivative wrt kernel for 1D convolution (subroutine version)
163 implicit none
164
165 class(array_type), intent(in) :: this
166 real(real32), dimension(:,:), intent(in) :: upstream_grad
167 real(real32), dimension(:,:), intent(out) :: output
168
169 ! Local variables
170 integer :: i, k, c_in, c_out, s
171 integer :: i_in, k_idx, out_idx
172 integer :: input_h, kernel_h, output_h, num_channels, num_filters
173 integer :: stride, dilation
174 real(real32) :: grad_sum
175
176
177 ! Unpack parameters
178 num_channels = this%indices(1)
179 num_filters = this%indices(2)
180 stride = this%adj_ja(1,1)
181 dilation = this%adj_ja(1,2)
182 kernel_h = this%adj_ja(1,3)
183
184 input_h = this%left_operand%shape(1)
185 output_h = this%shape(1)
186
187 output = 0._real32
188
189 ! Parallelised over filters, channels, and kernel positions
190 do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, k = 1:kernel_h)
191 k_idx = k + ( c_in - 1 ) * kernel_h + &
192 ( c_out - 1 ) * kernel_h * num_channels
193
194 grad_sum = 0._real32
195 do s = 1, size(upstream_grad, dim=2)
196 do i = 1, output_h
197 i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1
198 if(i_in .ge. 1 .and. i_in .le. input_h)then
199 out_idx = i + ( c_out - 1 ) * output_h
200 grad_sum = grad_sum + upstream_grad(out_idx, s) * &
201 this%left_operand%val(i_in + ( c_in - 1 ) * input_h, s)
202 end if
203 end do
204 end do
205 output(k_idx, 1) = grad_sum
206 end do
207
208 end subroutine get_partial_conv1d_kernel_val
209 !###############################################################################
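
Editor's note: the standalone sketch below is illustrative only and is not part of the measured source file. Assuming example sizes (input_h = 6, kernel_h = 3, two channels, a single filter, unit stride and dilation), it reproduces the output-length formula and the flattened [H, C] index arithmetic used by conv1d above on plain arrays, so the expected values can be checked by hand.

program conv1d_sketch
   use iso_fortran_env, only: real32
   implicit none
   integer, parameter :: input_h = 6, kernel_h = 3, num_channels = 2
   integer, parameter :: stride = 1, dilation = 1
   integer :: output_h, i, k, c_in, i_in
   real(real32) :: input(input_h*num_channels), kernel(kernel_h*num_channels)
   real(real32), allocatable :: output(:)
   real(real32) :: conv_sum

   input  = 1.0_real32
   kernel = 0.5_real32

   ! Same output-length formula as conv1d: (6 - 1*(3 - 1) - 1)/1 + 1 = 4
   output_h = (input_h - dilation*(kernel_h - 1) - 1) / stride + 1
   allocate(output(output_h))

   ! Same flattened [H, C] indexing as the forward pass (single filter assumed)
   do i = 1, output_h
      conv_sum = 0.0_real32
      do c_in = 1, num_channels
         do k = 1, kernel_h
            i_in = (i - 1)*stride + (k - 1)*dilation + 1
            conv_sum = conv_sum &
                 + input(i_in + (c_in - 1)*input_h) &
                 * kernel(k + (c_in - 1)*kernel_h)
         end do
      end do
      output(i) = conv_sum
   end do

   print '(a,i0)',   'output_h  = ', output_h   ! expected: 4
   print '(a,f6.2)', 'output(1) = ', output(1)  ! expected: 3.00
end program conv1d_sketch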
210
211
212 !###############################################################################
213 module function conv2d(input, kernel, stride, dilation) result(output)
214 !! 2D convolution operation
215 implicit none
216
217 ! Arguments
218 type(array_type), intent(in), target :: input
219 type(array_type), intent(in), target :: kernel
220 integer, dimension(2), intent(in) :: stride
221 integer, dimension(2), intent(in) :: dilation
222 type(array_type), pointer :: output
223
224 ! Local variables
225 integer :: i, j, ki, kj, c_in, c_out, s
226 integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx
227 integer :: input_h, input_w, kernel_h, kernel_w
228 integer :: output_h, output_w, num_channels, num_filters
229 integer :: channel_size_in, channel_size_out, kernel_channel_size
230 integer :: dil_kernel_h_m1, dil_kernel_w_m1
231 real(real32) :: conv_sum
232 integer, dimension(4) :: output_shape
233
234 ! Extract dimensions
235 ! input: [H_in, W_in, C_in, B]
236 ! kernel: [K_h, K_w, C_in, C_out]
237 input_h = input%shape(1)
238 input_w = input%shape(2)
239 num_channels = input%shape(3)
240 kernel_h = kernel%shape(1)
241 kernel_w = kernel%shape(2)
242 num_filters = kernel%shape(4)
243
244 ! Pre-compute common values
245 channel_size_in = input_h * input_w
246 kernel_channel_size = kernel_h * kernel_w
247 dil_kernel_h_m1 = dilation(1) * (kernel_h - 1)
248 dil_kernel_w_m1 = dilation(2) * (kernel_w - 1)
249
250 ! Calculate output dimensions
251 output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1
252 output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1
253 output_shape = [output_h, output_w, num_filters, &
254 size(input%val, dim=2)]
255
256 output => input%create_result(array_shape = output_shape)
257 output%val = 0._real32
258
259 channel_size_out = output_h * output_w
260
261 ! Perform convolution
262 do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, &
263 j = 1:output_w, i = 1:output_h)
264 conv_sum = 0._real32
265 do c_in = 1, num_channels
266 in_base_idx = (c_in - 1) * channel_size_in
267 k_base_idx = (c_in - 1) * kernel_channel_size + &
268 (c_out - 1) * kernel_channel_size * num_channels
269 do kj = 1, kernel_w
270 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
271 if(j_in .ge. 1 .and. j_in .le. input_w)then
272 do ki = 1, kernel_h
273 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
274 if(i_in .ge. 1 .and. i_in .le. input_h)then
275 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
276 k_idx = ki + (kj - 1) * kernel_h + k_base_idx
277 conv_sum = conv_sum + input%val(in_idx, s) * &
278 kernel%val(k_idx, 1)
279 end if
280 end do
281 end if
282 end do
283 end do
284 out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out
285 output%val(out_idx, s) = conv_sum
286 end do
287
288 ! Store parameters for backward pass
289 allocate(output%indices(2))
290 output%indices(1) = num_channels
291 output%indices(2) = num_filters
292 allocate(output%adj_ja(2,3))
293 output%adj_ja(1:2,1) = stride
294 output%adj_ja(1:2,2) = dilation
295 output%adj_ja(1,3) = kernel_h
296 output%adj_ja(2,3) = kernel_w
297
298
299 output%get_partial_left => get_partial_conv2d_input
300 output%get_partial_right => get_partial_conv2d_kernel
301 output%get_partial_left_val => get_partial_conv2d_input_val
302 output%get_partial_right_val => get_partial_conv2d_kernel_val
303 if(input%requires_grad .or. kernel%requires_grad)then
304 output%requires_grad = .true.
305 output%is_forward = input%is_forward
306 output%operation = 'conv2d'
307 output%left_operand => input
308 output%right_operand => kernel
309 end if
310
311 end function conv2d
312 !-------------------------------------------------------------------------------
313 function get_partial_conv2d_input(this, upstream_grad) result(output)
314 !! Get partial derivative wrt input for 2D convolution
315 implicit none
316
317 class(array_type), intent(inout) :: this
318 type(array_type), intent(in) :: upstream_grad
319 type(array_type) :: output
320
321 call output%allocate(array_shape = [ this%left_operand%shape, &
322 size(this%left_operand%val, dim=2) ])
323 call this%get_partial_left_val(upstream_grad%val, output%val)
324
325 end function get_partial_conv2d_input
326 !-------------------------------------------------------------------------------
327 function get_partial_conv2d_kernel(this, upstream_grad) result(output)
328 !! Get partial derivative wrt kernel for 2D convolution
329 implicit none
330
331 class(array_type), intent(inout) :: this
332 type(array_type), intent(in) :: upstream_grad
333 type(array_type) :: output
334
335 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
336 call this%get_partial_right_val(upstream_grad%val, output%val)
337
338 end function get_partial_conv2d_kernel
339 !-------------------------------------------------------------------------------
340 pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output)
341 !! Get partial derivative wrt input for 2D convolution (subroutine version)
342 implicit none
343
344 class(array_type), intent(in) :: this
345 real(real32), dimension(:,:), intent(in) :: upstream_grad
346 real(real32), dimension(:,:), intent(out) :: output
347
348 ! Local variables
349 integer :: i, j, ki, kj, c_in, c_out, s
350 integer :: i_in, j_in, k_idx, out_idx, in_idx
351 integer :: in_base_idx, k_base_idx, kernel_channel_size
352 integer :: input_h, input_w, kernel_h, kernel_w
353 integer :: output_h, output_w, num_channels, num_filters
354 integer, dimension(2) :: stride, dilation
355 integer :: channel_size_in, channel_size_out
356 real(real32) :: grad_val, kernel_val
357
358
359 ! Unpack parameters
360 num_channels = this%indices(1)
361 num_filters = this%indices(2)
362 stride = this%adj_ja(1:2,1)
363 dilation = this%adj_ja(1:2,2)
364 kernel_h = this%adj_ja(1,3)
365 kernel_w = this%adj_ja(2,3)
366
367 input_h = this%left_operand%shape(1)
368 input_w = this%left_operand%shape(2)
369 output_h = this%shape(1)
370 output_w = this%shape(2)
371 channel_size_in = input_h * input_w
372 channel_size_out = output_h * output_w
373 kernel_channel_size = kernel_h * kernel_w
374
375 output = 0._real32
376
377 ! Parallelised over batch, output channels, output spatial dims
378 do concurrent(s = 1:size(upstream_grad, dim=2), c_out = 1:num_filters, &
379 j = 1:output_w, i = 1:output_h)
380 out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out
381 grad_val = upstream_grad(out_idx, s)
382
383 if(abs(grad_val) .gt. 1.e-30_real32)then
384 do c_in = 1, num_channels
385 in_base_idx = (c_in - 1) * channel_size_in
386 k_base_idx = (c_in - 1) * kernel_channel_size + &
387 (c_out - 1) * kernel_channel_size * num_channels
388
389 do kj = 1, kernel_w
390 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
391 if(j_in .ge. 1 .and. j_in .le. input_w)then
392 do ki = 1, kernel_h
393 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
394 if(i_in .ge. 1 .and. i_in .le. input_h)then
395 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
396 k_idx = ki + (kj - 1) * kernel_h + &
397 k_base_idx
398 kernel_val = this%right_operand%val(k_idx, 1)
399 output(in_idx, s) = output(in_idx, s) + &
400 grad_val * kernel_val
401 end if
402 end do
403 end if
404 end do
405 end do
406 end if
407 end do
408
409 end subroutine get_partial_conv2d_input_val
410 !-------------------------------------------------------------------------------
411 pure subroutine get_partial_conv2d_kernel_val(this, upstream_grad, output)
412 !! Get partial derivative wrt kernel for 2D convolution (subroutine version)
413 implicit none
414
415 class(array_type), intent(in) :: this
416 real(real32), dimension(:,:), intent(in) :: upstream_grad
417 real(real32), dimension(:,:), intent(out) :: output
418
419 ! Local variables
420 integer :: i, j, ki, kj, c_in, c_out, s
421 integer :: i_in, j_in, k_idx, out_idx, in_idx
422 integer :: in_base_idx, out_base_idx, k_base_idx, kernel_channel_size
423 integer :: input_h, input_w, kernel_h, kernel_w
424 integer :: output_h, output_w, num_channels, num_filters
425 integer, dimension(2) :: stride, dilation
426 integer :: channel_size_in, channel_size_out
427 real(real32) :: grad_sum
428
429
430 ! Unpack parameters
431 num_channels = this%indices(1)
432 num_filters = this%indices(2)
433 stride = this%adj_ja(1:2,1)
434 dilation = this%adj_ja(1:2,2)
435 kernel_h = this%adj_ja(1,3)
436 kernel_w = this%adj_ja(2,3)
437
438 input_h = this%left_operand%shape(1)
439 input_w = this%left_operand%shape(2)
440 output_h = this%shape(1)
441 output_w = this%shape(2)
442 channel_size_in = input_h * input_w
443 channel_size_out = output_h * output_w
444 kernel_channel_size = kernel_h * kernel_w
445
446 output = 0._real32
447
448 ! Parallelised over filters, channels, and kernel dimensions
449 do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, &
450 kj = 1:kernel_w, ki = 1:kernel_h)
451 out_base_idx = (c_out - 1) * channel_size_out
452 in_base_idx = (c_in - 1) * channel_size_in
453 k_base_idx = (c_in - 1) * kernel_channel_size + &
454 (c_out - 1) * kernel_channel_size * num_channels
455 k_idx = ki + (kj - 1) * kernel_h + k_base_idx
456
457 grad_sum = 0._real32
458 do s = 1, size(upstream_grad, dim=2)
459 do j = 1, output_w
460 j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1
461 if(j_in .ge. 1 .and. j_in .le. input_w)then
462 do i = 1, output_h
463 i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1
464 if(i_in .ge. 1 .and. i_in .le. input_h)then
465 in_idx = i_in + (j_in - 1) * input_h + in_base_idx
466 out_idx = i + (j - 1) * output_h + out_base_idx
467 grad_sum = grad_sum + &
468 upstream_grad(out_idx, s) * this%left_operand%val(in_idx, s)
469 end if
470 end do
471 end if
472 end do
473 end do
474 output(k_idx, 1) = grad_sum
475 end do
476
477 end subroutine get_partial_conv2d_kernel_val
478 !###############################################################################
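
Editor's note: another illustrative sketch, not part of the measured source file. conv2d addresses the flattened val(:,:) storage with in_idx = i_in + (j_in - 1)*input_h + (c_in - 1)*input_h*input_w; the hypothetical program below, with assumed array sizes and one assumed element, confirms that this arithmetic matches ordinary column-major 3-D indexing.

program conv2d_index_sketch
   use iso_fortran_env, only: real32
   implicit none
   integer, parameter :: input_h = 4, input_w = 3, num_channels = 2
   real(real32) :: a3(input_h, input_w, num_channels)
   real(real32) :: flat(input_h*input_w*num_channels)
   integer :: i_in, j_in, c_in, in_idx, n

   ! Fill with distinct values; Fortran storage and reshape are column-major
   a3   = reshape([(real(n, real32), n = 1, size(a3))], shape(a3))
   flat = reshape(a3, [size(a3)])

   ! The flat index used throughout conv2d for element (i_in, j_in, c_in)
   i_in = 2; j_in = 3; c_in = 2
   in_idx = i_in + (j_in - 1)*input_h + (c_in - 1)*input_h*input_w

   print '(a,l1)', 'flat index matches 3-D index: ', &
        abs(flat(in_idx) - a3(i_in, j_in, c_in)) < 1.e-6_real32
end program conv2d_index_sketch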
479
480
481 !###############################################################################
482 module function conv3d(input, kernel, stride, dilation) result(output)
483 !! 3D convolution operation
484 implicit none
485
486 ! Arguments
487 type(array_type), intent(in), target :: input
488 type(array_type), intent(in), target :: kernel
489 integer, dimension(3), intent(in) :: stride
490 integer, dimension(3), intent(in) :: dilation
491 type(array_type), pointer :: output
492
493 ! Local variables
494 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
495 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
496 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
497 integer :: output_h, output_w, output_d
498 integer :: num_channels, num_filters
499 integer :: channel_size_in, channel_size_out
500 real(real32) :: conv_sum
501 integer, dimension(5) :: output_shape
502
503 ! Extract dimensions
504 ! input: [H_in, W_in, D_in, C_in, B]
505 ! kernel: [K_h, K_w, K_d, C_in, C_out]
506 input_h = input%shape(1)
507 input_w = input%shape(2)
508 input_d = input%shape(3)
509 num_channels = input%shape(4)
510 kernel_h = kernel%shape(1)
511 kernel_w = kernel%shape(2)
512 kernel_d = kernel%shape(3)
513 num_filters = kernel%shape(5)
514
515 ! Calculate output dimensions
516 output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / &
517 stride(1) + 1
518 output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / &
519 stride(2) + 1
520 output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / &
521 stride(3) + 1
522 output_shape = [output_h, output_w, output_d, num_filters, &
523 size(input%val, dim=2)]
524
525 output => input%create_result(array_shape = output_shape)
526 output%val = 0._real32
527
528 channel_size_in = input_h * input_w * input_d
529 channel_size_out = output_h * output_w * output_d
530
531 ! Perform convolution - optimised with do concurrent
532 do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, &
533 k = 1:output_d, j = 1:output_w, i = 1:output_h)
534
535 conv_sum = 0._real32
536 do c_in = 1, num_channels
537 do kk = 1, kernel_d
538 k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1
539 if(k_in .ge. 1 .and. k_in .le. input_d)then
540 do kj = 1, kernel_w
541 j_in = ( j - 1 ) * stride(2) + (kj - 1) * dilation(2) + 1
542 if(j_in .ge. 1 .and. j_in .le. input_w)then
543 do ki = 1, kernel_h
544 i_in = ( i - 1 ) * stride(1) + &
545 ( ki - 1 ) * dilation(1) + 1
546 if(i_in .ge. 1 .and. i_in .le. input_h)then
547 in_idx = i_in + ( j_in - 1 ) * input_h + &
548 ( k_in - 1 ) * input_h * input_w + &
549 ( c_in - 1 ) * channel_size_in
550 k_idx = ki + ( kj - 1 ) * kernel_h + &
551 ( kk - 1 ) * kernel_h * kernel_w + &
552 ( c_in - 1 ) * kernel_h * kernel_w * &
553 kernel_d + &
554 ( c_out - 1 ) * kernel_h * kernel_w * &
555 kernel_d * num_channels
556 conv_sum = conv_sum + input%val(in_idx, s) * &
557 kernel%val(k_idx, 1)
558 end if
559 end do
560 end if
561 end do
562 end if
563 end do
564 end do
565 out_idx = i + ( j - 1 ) * output_h + &
566 ( k - 1 ) * output_h * output_w + &
567 ( c_out - 1 ) * channel_size_out
568 output%val(out_idx, s) = conv_sum
569 end do
570
571 ! Store parameters for backward pass
572 allocate(output%indices(2))
573 output%indices(1) = num_channels
574 output%indices(2) = num_filters
575 allocate(output%adj_ja(3,3))
576 output%adj_ja(1:3,1) = stride
577 output%adj_ja(1:3,2) = dilation
578 output%adj_ja(1,3) = kernel_h
579 output%adj_ja(2,3) = kernel_w
580 output%adj_ja(3,3) = kernel_d
581
582 output%get_partial_left => get_partial_conv3d_input
583 output%get_partial_right => get_partial_conv3d_kernel
584 output%get_partial_left_val => get_partial_conv3d_input_val
585 output%get_partial_right_val => get_partial_conv3d_kernel_val
586 if(input%requires_grad .or. kernel%requires_grad)then
587 output%requires_grad = .true.
588 output%is_forward = input%is_forward
589 output%operation = 'conv3d'
590 output%left_operand => input
591 output%right_operand => kernel
592 end if
593
594 end function conv3d
595 !-------------------------------------------------------------------------------
596 function get_partial_conv3d_input(this, upstream_grad) result(output)
597 !! Get partial derivative wrt input for 3D convolution
598 implicit none
599
600 class(array_type), intent(inout) :: this
601 type(array_type), intent(in) :: upstream_grad
602 type(array_type) :: output
603
604 call output%allocate(array_shape = [ this%left_operand%shape, &
605 size(this%left_operand%val, dim=2) ])
606 call this%get_partial_left_val(upstream_grad%val, output%val)
607
608 end function get_partial_conv3d_input
609 !-------------------------------------------------------------------------------
610 function get_partial_conv3d_kernel(this, upstream_grad) result(output)
611 !! Get partial derivative wrt kernel for 3D convolution
612 implicit none
613
614 class(array_type), intent(inout) :: this
615 type(array_type), intent(in) :: upstream_grad
616 type(array_type) :: output
617
618 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
619 call this%get_partial_right_val(upstream_grad%val, output%val)
620
621 end function get_partial_conv3d_kernel
622 !-------------------------------------------------------------------------------
623 pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output)
624 !! Get partial derivative wrt input for 3D convolution (subroutine version)
625 implicit none
626
627 class(array_type), intent(in) :: this
628 real(real32), dimension(:,:), intent(in) :: upstream_grad
629 real(real32), dimension(:,:), intent(out) :: output
630
631 ! Local variables
632 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
633 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
634 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
635 integer :: output_h, output_w, output_d
636 integer :: num_channels, num_filters
637 integer :: channel_size_in, channel_size_out
638 integer, dimension(3) :: stride, dilation
639 real(real32) :: grad_val, kernel_val
640
641
642 ! Unpack parameters
643 num_channels = this%indices(1)
644 num_filters = this%indices(2)
645 stride = this%adj_ja(1:3,1)
646 dilation = this%adj_ja(1:3,2)
647 kernel_h = this%adj_ja(1,3)
648 kernel_w = this%adj_ja(2,3)
649 kernel_d = this%adj_ja(3,3)
650
651 input_h = this%left_operand%shape(1)
652 input_w = this%left_operand%shape(2)
653 input_d = this%left_operand%shape(3)
654 output_h = this%shape(1)
655 output_w = this%shape(2)
656 output_d = this%shape(3)
657
658 output = 0._real32
659
660 channel_size_in = input_h * input_w * input_d
661 channel_size_out = output_h * output_w * output_d
662 ! Accumulate input gradients sequentially over batch, channels, and output positions
663 do s = 1, size(upstream_grad, dim=2)
664 do c_in = 1, num_channels
665 do k = 1, output_d
666 do j = 1, output_w
667 do i = 1, output_h
668 do c_out = 1, num_filters
669 out_idx = i + ( j - 1 ) * output_h + &
670 ( k - 1 ) * output_h * output_w + &
671 ( c_out - 1 ) * channel_size_out
672 grad_val = upstream_grad(out_idx, s)
673
674 do kk = 1, kernel_d
675 k_in = ( k - 1 ) * stride(3) + &
676 ( kk - 1 ) * dilation(3) + 1
677 if( k_in .ge. 1 .and. k_in .le. input_d )then
678 do kj = 1, kernel_w
679 j_in = ( j - 1 ) * stride(2) + &
680 ( kj - 1 ) * dilation(2) + 1
681 if( j_in .ge. 1 .and. j_in .le. input_w )then
682 do ki = 1, kernel_h
683 i_in = ( i - 1 ) * stride(1) + &
684 ( ki - 1 ) * dilation(1) + 1
685 if( i_in .ge. 1 .and. &
686 i_in .le. input_h )then
687 in_idx = i_in + &
688 ( j_in - 1 ) * input_h + &
689 ( k_in - 1 ) * input_h * input_w + &
690 ( c_in - 1 ) * channel_size_in
691 k_idx = ki + ( kj - 1 ) * kernel_h + &
692 ( kk - 1 ) * kernel_h * kernel_w + &
693 ( c_in - 1 ) * kernel_h * &
694 kernel_w * kernel_d + &
695 ( c_out - 1 ) * kernel_h * &
696 kernel_w * kernel_d * num_channels
697 kernel_val = this%right_operand%val(k_idx, 1)
698 output(in_idx, s) = &
699 output(in_idx, s) + &
700 grad_val * kernel_val
701 end if
702 end do
703 end if
704 end do
705 end if
706 end do
707 end do
708 end do
709 end do
710 end do
711 end do
712 end do
713
714 end subroutine get_partial_conv3d_input_val
715 !-------------------------------------------------------------------------------
716 pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output)
717 !! Get partial derivative wrt kernel for 3D convolution (subroutine version)
718 implicit none
719
720 class(array_type), intent(in) :: this
721 real(real32), dimension(:,:), intent(in) :: upstream_grad
722 real(real32), dimension(:,:), intent(out) :: output
723
724 ! Local variables
725 integer :: i, j, k, ki, kj, kk, c_in, c_out, s
726 integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx
727 integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d
728 integer :: output_h, output_w, output_d
729 integer :: num_channels, num_filters
730 integer :: channel_size_in, channel_size_out
731 integer, dimension(3) :: stride, dilation
732 real(real32) :: grad_sum
733
734
735 ! Unpack parameters
736 num_channels = this%indices(1)
737 num_filters = this%indices(2)
738 stride = this%adj_ja(1:3,1)
739 dilation = this%adj_ja(1:3,2)
740 kernel_h = this%adj_ja(1,3)
741 kernel_w = this%adj_ja(2,3)
742 kernel_d = this%adj_ja(3,3)
743
744 input_h = this%left_operand%shape(1)
745 input_w = this%left_operand%shape(2)
746 input_d = this%left_operand%shape(3)
747 output_h = this%shape(1)
748 output_w = this%shape(2)
749 output_d = this%shape(3)
750
751 output = 0._real32
752
753 channel_size_in = input_h * input_w * input_d
754 channel_size_out = output_h * output_w * output_d
755 ! Accumulate kernel gradients sequentially over filters, channels, and kernel positions
756 do c_out = 1, num_filters
757 do c_in = 1, num_channels
758 do kk = 1, kernel_d
759 do kj = 1, kernel_w
760 do ki = 1, kernel_h
761 k_idx = ki + (kj-1)*kernel_h + &
762 (kk-1)*kernel_h*kernel_w + &
763 (c_in-1)*kernel_h*kernel_w*kernel_d + &
764 (c_out-1)*kernel_h*kernel_w*kernel_d*num_channels
765
766 grad_sum = 0._real32
767 do s = 1, size(upstream_grad, dim=2)
768 do k = 1, output_d
769 k_in = (k-1)*stride(3) + (kk-1)*dilation(3) + 1
770 if(k_in >= 1 .and. k_in <= input_d)then
771 do j = 1, output_w
772 j_in = (j-1)*stride(2) + &
773 (kj-1)*dilation(2) + 1
774 if(j_in >= 1 .and. j_in <= input_w)then
775 do i = 1, output_h
776 i_in = (i-1)*stride(1) + &
777 (ki-1)*dilation(1) + 1
778 if(i_in >= 1 .and. i_in <= input_h)then
779 in_idx = i_in + (j_in-1)*input_h + &
780 (k_in-1)*input_h*input_w + &
781 (c_in-1)*channel_size_in
782 out_idx = i + (j-1)*output_h + &
783 (k-1)*output_h*output_w + &
784 (c_out-1)*channel_size_out
785 grad_sum = grad_sum + &
786 upstream_grad(out_idx, s) * &
787 this%left_operand%val(in_idx, s)
788 end if
789 end do
790 end if
791 end do
792 end if
793 end do
794 end do
795 output(k_idx, 1) = grad_sum
796 end do
797 end do
798 end do
799 end do
800 end do
801
802 end subroutine get_partial_conv3d_kernel_val
803 !###############################################################################
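
Editor's note: a final illustrative sketch, not part of the measured source file. For a tiny single-channel, single-filter, unit-batch 1-D case with assumed values, it accumulates the kernel gradient the same way as the *_kernel_val routines above (grad(k) = sum over output positions of upstream times the input at the strided/dilated position) and checks it against a central finite difference of the loss L = sum(upstream * output).

program conv_kernel_grad_sketch
   use iso_fortran_env, only: real32
   implicit none
   integer, parameter :: input_h = 5, kernel_h = 3, stride = 1, dilation = 1
   integer, parameter :: output_h = (input_h - dilation*(kernel_h - 1) - 1)/stride + 1
   real(real32) :: input(input_h), kernel(kernel_h), upstream(output_h)
   real(real32) :: grad(kernel_h), kp(kernel_h), km(kernel_h), fd, eps
   integer :: i, k, i_in

   input    = [1.0, 2.0, 3.0, 4.0, 5.0]
   kernel   = [0.2, 0.5, 0.3]
   upstream = [1.0, -1.0, 2.0]
   eps      = 1.e-3_real32

   ! Analytic kernel gradient, accumulated as in the *_kernel_val routines
   do k = 1, kernel_h
      grad(k) = 0.0_real32
      do i = 1, output_h
         i_in = (i - 1)*stride + (k - 1)*dilation + 1
         grad(k) = grad(k) + upstream(i) * input(i_in)
      end do
   end do

   ! Central finite difference of L = sum(upstream * output) wrt kernel(2)
   kp = kernel; kp(2) = kp(2) + eps
   km = kernel; km(2) = km(2) - eps
   fd = (loss(kp) - loss(km)) / (2.0_real32*eps)

   print '(a,f8.4,a,f8.4)', 'analytic dL/dW(2) = ', grad(2), ',  finite diff = ', fd

contains

   pure function loss(w) result(l)
      real(real32), intent(in) :: w(kernel_h)
      real(real32) :: l, o
      integer :: i, k, i_in
      l = 0.0_real32
      do i = 1, output_h
         o = 0.0_real32
         do k = 1, kernel_h
            i_in = (i - 1)*stride + (k - 1)*dilation + 1
            o = o + input(i_in) * w(k)
         end do
         l = l + upstream(i) * o
      end do
   end function loss

end program conv_kernel_grad_sketch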
804
805 end submodule athena__diffstruc_extd_submodule_conv
806