GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_pool.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_pool
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
module function avgpool1d(input, pool_size, stride) result(output)
  !! 1D average pooling operation.
  !!
  !! Output position `i` is the mean of the `pool_size` input values that
  !! start at spatial position `(i-1)*stride + 1`. `input%shape(2)` is
  !! carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! longer than the input yields a zero-length spatial dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, intent(in) :: pool_size
  !! Window length (assumed positive)
  integer, intent(in) :: stride
  !! Step between consecutive window starts (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, m, s
  integer :: stride_idx, idx
  integer, dimension(3) :: output_shape

  ! Integer division truncates toward zero, so (L - P)/S + 1 can be
  ! positive even when P > L (e.g. L=3, P=5, S=3 gives 1), which would
  ! make the slice below read past the end of the channel. No window fits
  ! in that case, so the output length is clamped to zero.
  output_shape = [ &
       merge((input%shape(1) - pool_size) / stride + 1, 0, &
             input%shape(1) .ge. pool_size), &
       input%shape(2), &
       size(input%val, dim=2) ]
  output => input%create_result(array_shape = output_shape)

  do concurrent( &
       s = 1:output_shape(3), &
       m = 1:output_shape(2), &
       i = 1:output_shape(1))
    ! Offset of the window start within the flattened feature index
    stride_idx = (i - 1) * stride + (m - 1) * input%shape(1)
    idx = i + (m - 1) * output_shape(1)
    output%val(idx, s) = sum( &
         input%val( stride_idx + 1 : stride_idx + pool_size, s ) &
    ) / pool_size
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(1,2))
  output%adj_ja(1,1) = pool_size
  output%adj_ja(1,2) = stride

  output%get_partial_left => get_partial_avgpool1d
  output%get_partial_left_val => get_partial_avgpool1d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'avgpool'
    output%left_operand => input
  end if

end function avgpool1d
51 !-------------------------------------------------------------------------------
function get_partial_avgpool1d(this, upstream_grad) result(output)
  !! Gradient of a 1D average-pooling node with respect to its input.
  !!
  !! Allocates `output` to the input operand's shape, extended by the
  !! size of the second dimension of `this%val`, then delegates the
  !! numerical work to the `get_partial_left_val` procedure pointer.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_avgpool1d
67 !-------------------------------------------------------------------------------
pure subroutine get_partial_avgpool1d_val(this, upstream_grad, output)
  !! Backward pass for 1D average pooling.
  !!
  !! Every input position inside a pooling window receives that window's
  !! upstream gradient divided by the window length. Positions covered by
  !! several windows (stride < pool_size) accumulate all contributions.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, m, s, p
  integer :: base_idx, out_idx, input_h, n_samples
  integer :: pool_size, stride
  real(real32) :: pool_norm, grad_val

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(1,1)
  stride = this%adj_ja(1,2)
  input_h = this%left_operand%shape(1)
  n_samples = size(this%val, dim=2)
  pool_norm = 1.0_real32 / real(pool_size, real32)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) scatter into the same output
  ! element from different spatial positions. DO CONCURRENT forbids one
  ! iteration defining what another references (and it races when
  ! parallelised), so only the disjoint (sample, channel) slices run
  ! concurrently and the spatial sweep stays serial.
  do concurrent(s = 1:n_samples, m = 1:this%shape(2))
    do i = 1, this%shape(1)
      base_idx = (i - 1) * stride + (m - 1) * input_h
      out_idx = i + (m - 1) * this%shape(1)
      grad_val = upstream_grad(out_idx, s) * pool_norm

      ! Distribute gradient over pooling window
      do p = base_idx + 1, base_idx + pool_size
        output(p, s) = output(p, s) + grad_val
      end do
    end do
  end do

end subroutine get_partial_avgpool1d_val
110 !###############################################################################
111
112
113 !###############################################################################
module function avgpool2d(input, pool_size, stride) result(output)
  !! 2D average pooling operation.
  !!
  !! Pools over the first two dimensions of `input%shape`. The third is
  !! carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! larger than the input along a dimension yields a zero-length output
  !! dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, dimension(2), intent(in) :: pool_size
  !! Window extent per pooled dimension (assumed positive)
  integer, dimension(2), intent(in) :: stride
  !! Step between window starts per pooled dimension (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, j, m, s, i_step, j_step
  integer :: stride_idx, idx
  integer :: channel_size_in, channel_size_out
  real(real32) :: pool_sum, pool_norm
  integer, dimension(4) :: output_shape

  ! Integer division truncates toward zero, so (L - P)/S + 1 can report a
  ! window when P > L, and the loop below would read outside the channel.
  ! Clamp those dimensions to zero windows.
  output_shape(1:2) = merge( &
       (input%shape(1:2) - pool_size) / stride + 1, 0, &
       input%shape(1:2) .ge. pool_size)
  output_shape(3) = input%shape(3)
  output_shape(4) = size(input%val, dim=2)
  output => input%create_result(array_shape = output_shape)
  pool_norm = 1.0_real32 / real(pool_size(1) * pool_size(2), real32)

  ! Pre-compute as integers
  channel_size_in = input%shape(1) * input%shape(2)
  channel_size_out = output_shape(1) * output_shape(2)

  do concurrent( &
       s = 1:output_shape(4), &
       m = 1:output_shape(3), &
       j = 1:output_shape(2), &
       i = 1:output_shape(1))

    ! Flattened offset of the window's first element, and output index
    stride_idx = (i-1)*stride(1) + &
         ((j-1)*stride(2)) * input%shape(1) + &
         (m-1) * channel_size_in
    idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out

    pool_sum = 0._real32
    do j_step = 0, pool_size(2)-1
      do i_step = 0, pool_size(1)-1
        pool_sum = pool_sum + &
             input%val(stride_idx + i_step + j_step * input%shape(1) + 1, s)
      end do
    end do
    output%val(idx, s) = pool_sum * pool_norm
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(2,2))
  output%adj_ja(:,1) = pool_size
  output%adj_ja(:,2) = stride

  output%get_partial_left => get_partial_avgpool2d
  output%get_partial_left_val => get_partial_avgpool2d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'avgpool'
    output%left_operand => input
  end if

end function avgpool2d
178 !-------------------------------------------------------------------------------
function get_partial_avgpool2d(this, upstream_grad) result(output)
  !! Gradient of a 2D average-pooling node with respect to its input.
  !!
  !! The result is sized like the pooled operand plus a trailing dimension
  !! equal to the second extent of `this%val`. It is filled by whatever
  !! procedure `this%get_partial_left_val` points to.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_avgpool2d
194 !-------------------------------------------------------------------------------
pure subroutine get_partial_avgpool2d_val(this, upstream_grad, output)
  !! Backward pass for 2D average pooling.
  !!
  !! Each input element inside a pooling window receives the window's
  !! upstream gradient scaled by 1/(pool_size(1)*pool_size(2)). Elements
  !! shared by overlapping windows accumulate every contribution.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, j, m, s
  integer :: i_step, j_step
  integer :: base_idx, in_idx, out_idx, input_h
  integer :: channel_size_in, channel_size_out, n_samples
  real(real32) :: pool_norm, grad_val
  integer, dimension(2) :: pool_size, stride

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)
  input_h = this%left_operand%shape(1)
  channel_size_in = input_h * this%left_operand%shape(2)
  channel_size_out = this%shape(1) * this%shape(2)
  n_samples = size(this%val, dim=2)
  pool_norm = 1.0_real32 / real(pool_size(1) * pool_size(2), real32)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) make different spatial
  ! positions update the same `output` element. DO CONCURRENT forbids that
  ! across iterations (and it races when parallelised), so only the
  ! disjoint (sample, channel) slices run concurrently.
  do concurrent(s = 1:n_samples, m = 1:this%shape(3))
    do j = 1, this%shape(2)
      do i = 1, this%shape(1)
        base_idx = (i-1) * stride(1) + ((j-1) * stride(2)) * input_h + &
             (m-1) * channel_size_in
        out_idx = i + (j-1) * this%shape(1) + (m-1) * channel_size_out
        grad_val = upstream_grad(out_idx, s) * pool_norm

        ! Distribute gradient over pooling window
        do j_step = 0, pool_size(2) - 1
          do i_step = 0, pool_size(1) - 1
            in_idx = base_idx + i_step + j_step * input_h + 1
            output(in_idx, s) = output(in_idx, s) + grad_val
          end do
        end do
      end do
    end do
  end do

end subroutine get_partial_avgpool2d_val
247 !###############################################################################
248
249
250 !###############################################################################
module function avgpool3d(input, pool_size, stride) result(output)
  !! 3D average pooling operation.
  !!
  !! Pools over the first three dimensions of `input%shape`. The fourth is
  !! carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! larger than the input along a dimension yields a zero-length output
  !! dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, dimension(3), intent(in) :: pool_size
  !! Window extent per pooled dimension (assumed positive)
  integer, dimension(3), intent(in) :: stride
  !! Step between window starts per pooled dimension (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, j, k, m, s
  integer :: i_step, j_step, k_step
  integer :: stride_idx, idx
  integer :: channel_size_in, channel_size_out
  real(real32) :: pool_sum, pool_norm
  integer, dimension(5) :: output_shape

  ! output_shape = [H_out, W_out, D_out, C, B]
  ! Integer division truncates toward zero, so (L - P)/S + 1 can report a
  ! window when P > L, and the loop below would read outside the channel.
  ! Clamp those dimensions to zero windows.
  output_shape(1:3) = merge( &
       (input%shape(1:3) - pool_size) / stride + 1, 0, &
       input%shape(1:3) .ge. pool_size)
  output_shape(4) = input%shape(4)
  output_shape(5) = size(input%val, dim=2)

  output => input%create_result(array_shape = output_shape)
  pool_norm = 1.0_real32 / real(product(pool_size), real32)

  ! Pre-compute as integers
  channel_size_in = input%shape(1) * input%shape(2) * input%shape(3)
  channel_size_out = output_shape(1) * output_shape(2) * output_shape(3)

  do concurrent( &
       s = 1:output_shape(5), &
       m = 1:output_shape(4), &
       k = 1:output_shape(3), &
       j = 1:output_shape(2), &
       i = 1:output_shape(1))

    ! Flattened offset of the window's first element, and output index
    stride_idx = ((i-1)*stride(1)) + &
         ((j-1)*stride(2)) * input%shape(1) + &
         ((k-1)*stride(3)) * input%shape(1) * input%shape(2) + &
         (m-1) * channel_size_in
    idx = i + (j-1) * output_shape(1) + &
         (k-1) * output_shape(1)*output_shape(2) + &
         (m-1) * channel_size_out

    pool_sum = 0._real32
    do k_step = 0, pool_size(3)-1
      do j_step = 0, pool_size(2)-1
        do i_step = 0, pool_size(1)-1
          pool_sum = pool_sum + input%val(stride_idx + i_step + &
               j_step * input%shape(1) + &
               k_step * input%shape(1) * input%shape(2) + 1, s)
        end do
      end do
    end do

    output%val(idx, s) = pool_sum * pool_norm
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(3,2))
  output%adj_ja(:,1) = pool_size
  output%adj_ja(:,2) = stride

  output%get_partial_left => get_partial_avgpool3d
  output%get_partial_left_val => get_partial_avgpool3d_val
  if (input%requires_grad) then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'avgpool3d'
    output%left_operand => input
  end if

end function avgpool3d
328 !-------------------------------------------------------------------------------
function get_partial_avgpool3d(this, upstream_grad) result(output)
  !! Gradient of a 3D average-pooling node with respect to its input.
  !!
  !! Sizes the result like the pooled operand, extended by the second
  !! extent of `this%val`, and fills it through the node's
  !! `get_partial_left_val` procedure pointer.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_avgpool3d
344 !-------------------------------------------------------------------------------
pure subroutine get_partial_avgpool3d_val(this, upstream_grad, output)
  !! Backward pass for 3D average pooling.
  !!
  !! Each input element inside a pooling window receives the window's
  !! upstream gradient scaled by 1/product(pool_size). Elements shared by
  !! overlapping windows accumulate every contribution.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, j, k, m, s
  integer :: i_step, j_step, k_step
  integer :: base_idx, in_idx, out_idx, input_h, input_hw
  integer :: channel_size_in, channel_size_out, n_samples
  real(real32) :: pool_norm, grad_val
  integer, dimension(3) :: pool_size, stride

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)
  input_h = this%left_operand%shape(1)
  input_hw = input_h * this%left_operand%shape(2)
  channel_size_in = input_hw * this%left_operand%shape(3)
  channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)
  n_samples = size(this%val, dim=2)
  pool_norm = 1.0_real32 / real(product(pool_size), real32)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) make different spatial
  ! positions update the same `output` element. DO CONCURRENT forbids that
  ! across iterations (and it races when parallelised), so only the
  ! disjoint (sample, channel) slices run concurrently.
  do concurrent(s = 1:n_samples, m = 1:this%shape(4))
    do k = 1, this%shape(3)
      do j = 1, this%shape(2)
        do i = 1, this%shape(1)
          base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
               ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in
          out_idx = i + (j-1) * this%shape(1) + &
               (k-1) * this%shape(1)*this%shape(2) + &
               (m-1) * channel_size_out
          grad_val = upstream_grad(out_idx, s) * pool_norm

          ! Distribute gradient over pooling window
          do k_step = 0, pool_size(3)-1
            do j_step = 0, pool_size(2)-1
              do i_step = 0, pool_size(1)-1
                in_idx = base_idx + i_step + j_step * input_h + &
                     k_step * input_hw + 1
                output(in_idx, s) = output(in_idx, s) + grad_val
              end do
            end do
          end do
        end do
      end do
    end do
  end do

end subroutine get_partial_avgpool3d_val
404 !###############################################################################
405
406
407 !##############################################################################!
408 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
409 !##############################################################################!
410
411
412 !###############################################################################
module function maxpool1d(input, pool_size, stride) result(output)
  !! 1D max pooling operation.
  !!
  !! Output position `i` holds the largest of the `pool_size` input values
  !! that start at spatial position `(i-1)*stride + 1`. `input%shape(2)`
  !! is carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! longer than the input yields a zero-length spatial dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, intent(in) :: pool_size
  !! Window length (assumed positive)
  integer, intent(in) :: stride
  !! Step between consecutive window starts (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, m, s
  integer :: stride_idx, idx
  integer, dimension(3) :: output_shape

  ! Integer division truncates toward zero, so (L - P)/S + 1 can be
  ! positive even when P > L (e.g. L=3, P=5, S=3 gives 1), which would
  ! make the MAXVAL slice below run past the end of the channel. No window
  ! fits in that case, so the output length is clamped to zero.
  output_shape = [ &
       merge((input%shape(1) - pool_size) / stride + 1, 0, &
             input%shape(1) .ge. pool_size), &
       input%shape(2), &
       size(input%val, dim=2) ]
  output => input%create_result(array_shape = output_shape)

  do concurrent( &
       s = 1:output_shape(3), &
       m = 1:output_shape(2), &
       i = 1:output_shape(1))
    ! Offset of the window start within the flattened feature index
    stride_idx = (i - 1) * stride + (m - 1) * input%shape(1)
    idx = i + (m - 1) * output_shape(1)
    output%val(idx, s) = maxval( &
         input%val( stride_idx + 1 : stride_idx + pool_size, s ) &
    )
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(1,2))
  output%adj_ja(1,1) = pool_size
  output%adj_ja(1,2) = stride

  output%get_partial_left => get_partial_maxpool1d
  output%get_partial_left_val => get_partial_maxpool1d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'maxpool'
    output%left_operand => input
  end if

end function maxpool1d
457 !-------------------------------------------------------------------------------
function get_partial_maxpool1d(this, upstream_grad) result(output)
  !! Gradient of a 1D max-pooling node with respect to its input.
  !!
  !! Allocates `output` to the input operand's shape, extended by the
  !! size of the second dimension of `this%val`, then delegates to the
  !! `get_partial_left_val` procedure pointer of `this`.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_maxpool1d
473 !-------------------------------------------------------------------------------
pure subroutine get_partial_maxpool1d_val(this, upstream_grad, output)
  !! Backward pass for 1D max pooling.
  !!
  !! Each window's upstream gradient is routed to the first position in
  !! the window that holds its maximum (ties resolve to the lowest index).
  !! Overlapping windows may route to the same element, in which case the
  !! contributions accumulate.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, m, s, p
  integer :: base_idx, max_idx, out_idx, input_h, n_samples
  integer :: pool_size, stride
  real(real32) :: pool_max

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(1,1)
  stride = this%adj_ja(1,2)
  input_h = this%left_operand%shape(1)
  n_samples = size(this%val, dim=2)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) can share their arg-max
  ! element, so different spatial positions may update the same `output`
  ! entry. DO CONCURRENT forbids that across iterations (and it races when
  ! parallelised), so only the disjoint (sample, channel) slices run
  ! concurrently.
  do concurrent(s = 1:n_samples, m = 1:this%shape(2))
    do i = 1, this%shape(1)
      base_idx = (i - 1) * stride + (m - 1) * input_h
      out_idx = i + (m - 1) * this%shape(1)

      ! Find max value location - initialise with first element
      max_idx = base_idx + 1
      pool_max = this%left_operand%val(max_idx, s)
      do p = base_idx + 2, base_idx + pool_size
        if(this%left_operand%val(p, s) > pool_max) then
          pool_max = this%left_operand%val(p, s)
          max_idx = p
        end if
      end do

      ! Assign gradient to max location
      output(max_idx, s) = output(max_idx, s) + upstream_grad(out_idx, s)
    end do
  end do

end subroutine get_partial_maxpool1d_val
522 !###############################################################################
523
524
525 !###############################################################################
module function maxpool2d(input, pool_size, stride) result(output)
  !! 2D max pooling operation.
  !!
  !! Pools over the first two dimensions of `input%shape`. The third is
  !! carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! larger than the input along a dimension yields a zero-length output
  !! dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, dimension(2), intent(in) :: pool_size
  !! Window extent per pooled dimension (assumed positive)
  integer, dimension(2), intent(in) :: stride
  !! Step between window starts per pooled dimension (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, j, m, s, i_step, j_step
  integer :: base_idx, stride_idx, idx, input_h
  real(real32) :: pool_max
  integer :: channel_size_in, channel_size_out
  integer, dimension(4) :: output_shape

  ! Integer division truncates toward zero, so (L - P)/S + 1 can report a
  ! window when P > L, and the loop below would read outside the channel.
  ! Clamp those dimensions to zero windows.
  output_shape(1:2) = merge( &
       (input%shape(1:2) - pool_size) / stride + 1, 0, &
       input%shape(1:2) .ge. pool_size)
  output_shape(3) = input%shape(3)
  output_shape(4) = size(input%val, dim=2)
  output => input%create_result(array_shape = output_shape)

  ! Pre-compute as integers to avoid type conversion in loop
  input_h = input%shape(1)
  channel_size_in = input_h * input%shape(2)
  channel_size_out = output_shape(1) * output_shape(2)

  do concurrent( &
       s = 1:output_shape(4), &
       m = 1:output_shape(3), &
       j = 1:output_shape(2), &
       i = 1:output_shape(1))

    ! Compute indices once per output position
    base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
         (m-1) * channel_size_in
    idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out

    ! Seed the running maximum with the window's first element
    stride_idx = base_idx + 1
    pool_max = input%val(stride_idx, s)

    ! Continue with remaining elements
    do j_step = 0, pool_size(2)-1
      do i_step = 0, pool_size(1)-1
        if(i_step .eq. 0 .and. j_step .eq. 0) cycle ! Already processed
        stride_idx = base_idx + i_step + j_step * input_h + 1
        if(input%val(stride_idx, s) > pool_max) &
             pool_max = input%val(stride_idx, s)
      end do
    end do

    output%val(idx, s) = pool_max
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(2,2))
  output%adj_ja(:,1) = pool_size
  output%adj_ja(:,2) = stride

  output%get_partial_left => get_partial_maxpool2d
  output%get_partial_left_val => get_partial_maxpool2d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'maxpool'
    output%left_operand => input
  end if

end function maxpool2d
597 !-------------------------------------------------------------------------------
function get_partial_maxpool2d(this, upstream_grad) result(output)
  !! Gradient of a 2D max-pooling node with respect to its input.
  !!
  !! The result is sized like the pooled operand plus a trailing dimension
  !! equal to the second extent of `this%val`. It is filled by whatever
  !! procedure `this%get_partial_left_val` points to.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_maxpool2d
613 !-------------------------------------------------------------------------------
pure subroutine get_partial_maxpool2d_val(this, upstream_grad, output)
  !! Backward pass for 2D max pooling.
  !!
  !! Each window's upstream gradient is routed to the first element, in
  !! column-major window order, that holds the window's maximum. Elements
  !! selected by several overlapping windows accumulate all contributions.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, j, m, s
  integer :: i_step, j_step
  integer :: base_idx, in_idx, out_idx, max_idx, input_h
  integer :: channel_size_in, channel_size_out, n_samples
  real(real32) :: pool_max, val_tmp
  integer, dimension(2) :: pool_size, stride

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)
  input_h = this%left_operand%shape(1)
  channel_size_in = input_h * this%left_operand%shape(2)
  channel_size_out = this%shape(1) * this%shape(2)
  n_samples = size(this%val, dim=2)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) can share their arg-max
  ! element, so different spatial positions may update the same `output`
  ! entry. DO CONCURRENT forbids that across iterations (and it races when
  ! parallelised), so only the disjoint (sample, channel) slices run
  ! concurrently.
  do concurrent(s = 1:n_samples, m = 1:this%shape(3))
    do j = 1, this%shape(2)
      do i = 1, this%shape(1)
        base_idx = (i-1) * stride(1) + ((j-1) * stride(2)) * input_h + &
             (m-1) * channel_size_in
        out_idx = i + (j-1) * this%shape(1) + (m-1) * channel_size_out

        ! Find max value location - initialise with first element
        max_idx = base_idx + 1
        pool_max = this%left_operand%val(max_idx, s)

        ! Search remaining elements for max
        do j_step = 0, pool_size(2) - 1
          do i_step = 0, pool_size(1) - 1
            if(i_step == 0 .and. j_step == 0) cycle ! Already processed
            in_idx = base_idx + i_step + j_step * input_h + 1
            val_tmp = this%left_operand%val(in_idx, s)
            if (val_tmp .gt. pool_max) then
              pool_max = val_tmp
              max_idx = in_idx
            end if
          end do
        end do

        ! Assign gradient to max location
        output(max_idx, s) = output(max_idx, s) + upstream_grad(out_idx, s)
      end do
    end do
  end do

end subroutine get_partial_maxpool2d_val
674 !###############################################################################
675
676
677 !###############################################################################
module function maxpool3d(input, pool_size, stride) result(output)
  !! 3D max pooling operation.
  !!
  !! Pools over the first three dimensions of `input%shape`. The fourth is
  !! carried through unchanged, and the second dimension of `input%val`
  !! (looped over as `s`) becomes the last output dimension. A window
  !! larger than the input along a dimension yields a zero-length output
  !! dimension.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  !! Array to pool
  integer, dimension(3), intent(in) :: pool_size
  !! Window extent per pooled dimension (assumed positive)
  integer, dimension(3), intent(in) :: stride
  !! Step between window starts per pooled dimension (assumed positive)
  type(array_type), pointer :: output
  !! Pooled array, created by `input%create_result`

  ! Local variables
  integer :: i, j, k, m, s
  integer :: i_step, j_step, k_step
  integer :: stride_idx, in_idx, idx, input_h, input_hw
  integer :: channel_size_in, channel_size_out
  real(real32) :: pool_max
  integer, dimension(5) :: output_shape

  ! output_shape = [H_out, W_out, D_out, C, B]
  ! Integer division truncates toward zero, so (L - P)/S + 1 can report a
  ! window when P > L, and the loop below would read outside the channel.
  ! Clamp those dimensions to zero windows.
  output_shape(1:3) = merge( &
       (input%shape(1:3) - pool_size) / stride + 1, 0, &
       input%shape(1:3) .ge. pool_size)
  output_shape(4) = input%shape(4)
  output_shape(5) = size(input%val, dim=2)

  output => input%create_result(array_shape = output_shape)

  ! Pre-compute as integers
  input_h = input%shape(1)
  input_hw = input_h * input%shape(2)
  channel_size_in = input_hw * input%shape(3)
  channel_size_out = output_shape(1) * output_shape(2) * output_shape(3)

  do concurrent( &
       s = 1:output_shape(5), &
       m = 1:output_shape(4), &
       k = 1:output_shape(3), &
       j = 1:output_shape(2), &
       i = 1:output_shape(1))

    ! 1-based flat index of the window's first element, and output index
    stride_idx = ((i-1)*stride(1)) + &
         ((j-1)*stride(2)) * input_h + &
         ((k-1)*stride(3)) * input_hw + &
         (m-1) * channel_size_in + 1
    idx = i + (j-1) * output_shape(1) + &
         (k-1) * output_shape(1)*output_shape(2) + &
         (m-1) * channel_size_out

    ! Find max value - initialise with first element
    pool_max = input%val(stride_idx, s)

    do k_step = 0, pool_size(3)-1
      do j_step = 0, pool_size(2)-1
        do i_step = 0, pool_size(1)-1
          if(i_step == 0 .and. j_step == 0 .and. k_step == 0) cycle
          in_idx = stride_idx + i_step + j_step * input_h + &
               k_step * input_hw
          if(input%val(in_idx, s) .gt. pool_max) &
               pool_max = input%val(in_idx, s)
        end do
      end do
    end do

    output%val(idx, s) = pool_max
  end do

  ! Store the pooling parameters for the backward pass
  allocate(output%adj_ja(3,2))
  output%adj_ja(:,1) = pool_size
  output%adj_ja(:,2) = stride

  output%get_partial_left => get_partial_maxpool3d
  output%get_partial_left_val => get_partial_maxpool3d_val
  if (input%requires_grad) then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'maxpool3d'
    output%left_operand => input
  end if

end function maxpool3d
765 !-------------------------------------------------------------------------------
function get_partial_maxpool3d(this, upstream_grad) result(output)
  !! Gradient of a 3D max-pooling node with respect to its input.
  !!
  !! Sizes the result like the pooled operand, extended by the second
  !! extent of `this%val`, and fills it through the node's
  !! `get_partial_left_val` procedure pointer.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  !! Result node produced by the forward pooling call
  type(array_type), intent(in) :: upstream_grad
  !! Gradient arriving from downstream (presumably laid out like `this%val`)
  type(array_type) :: output
  !! Gradient with respect to `this%left_operand`

  ! Local variables
  integer :: n_samples

  n_samples = size(this%val, dim=2)
  call output%allocate( &
       array_shape = [ this%left_operand%shape, n_samples ])
  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_maxpool3d
781 !-------------------------------------------------------------------------------
pure subroutine get_partial_maxpool3d_val(this, upstream_grad, output)
  !! Backward pass for 3D max pooling.
  !!
  !! Each window's upstream gradient is routed to the first element, in
  !! column-major window order, that holds the window's maximum. Elements
  !! selected by several overlapping windows accumulate all contributions.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  !! Result node of the forward pass; `adj_ja` holds pool size and stride
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  !! Gradient with respect to `this`, indexed like `this%val`
  real(real32), dimension(:,:), intent(out) :: output
  !! Gradient with respect to `this%left_operand`, indexed like its `val`

  ! Local variables
  integer :: i, j, k, m, s
  integer :: i_step, j_step, k_step
  integer :: base_idx, in_idx, out_idx, max_idx
  integer :: input_h, input_hw
  integer :: channel_size_in, channel_size_out, n_samples
  real(real32) :: pool_max, val_tmp
  integer, dimension(3) :: pool_size, stride

  ! Unpack parameters stored by the forward pass
  pool_size = this%adj_ja(:,1)
  stride = this%adj_ja(:,2)
  input_h = this%left_operand%shape(1)
  input_hw = input_h * this%left_operand%shape(2)
  channel_size_in = input_hw * this%left_operand%shape(3)
  channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)
  n_samples = size(this%val, dim=2)

  output = 0._real32

  ! Overlapping windows (stride < pool_size) can share their arg-max
  ! element, so different spatial positions may update the same `output`
  ! entry. DO CONCURRENT forbids that across iterations (and it races when
  ! parallelised), so only the disjoint (sample, channel) slices run
  ! concurrently.
  do concurrent(s = 1:n_samples, m = 1:this%shape(4))
    do k = 1, this%shape(3)
      do j = 1, this%shape(2)
        do i = 1, this%shape(1)
          base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
               ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in
          out_idx = i + (j-1) * this%shape(1) + &
               (k-1) * this%shape(1)*this%shape(2) + &
               (m-1) * channel_size_out

          ! Find max value location - initialise with first element
          max_idx = base_idx + 1
          pool_max = this%left_operand%val(max_idx, s)

          ! Search remaining elements for max
          do k_step = 0, pool_size(3)-1
            do j_step = 0, pool_size(2)-1
              do i_step = 0, pool_size(1)-1
                if(i_step == 0 .and. j_step == 0 .and. k_step == 0) cycle
                in_idx = base_idx + i_step + j_step * input_h + &
                     k_step * input_hw + 1
                val_tmp = this%left_operand%val(in_idx, s)
                if (val_tmp .gt. pool_max) then
                  pool_max = val_tmp
                  max_idx = in_idx
                end if
              end do
            end do
          end do

          ! Assign gradient to max location
          output(max_idx, s) = output(max_idx, s) + &
               upstream_grad(out_idx, s)
        end do
      end do
    end do
  end do

end subroutine get_partial_maxpool3d_val
850 !###############################################################################
851
852 end submodule athena__diffstruc_extd_submodule_pool
853