GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_pad.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_pad
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
subroutine fill_edge_region_1d(input, output)
  !! Copy input values into the padded edge cells of a 1D padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (any other value leaves `output`
  !! unchanged); (2) number of padded cells written per facet; (3) number of
  !! edge facets. For facet `f`, `output%adj_ja(1, 2*f-1)` is the first
  !! source position in `input` and `output%adj_ja(2, 2*f-1)` the first
  !! destination position in `output`, both along dimension 1.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: facet, sample, channel, offset
  integer :: stride, src_first, dst_first, src_pos, dst_pos
  integer :: len_in, len_out, n_pad

  ! Source stride per padded cell: forwards for circular, backwards for
  ! reflection, and zero for replication (every cell reads the same value).
  select case(output%indices(1))
  case(3)
    stride = 1
  case(4)
    stride = -1
  case(5)
    stride = 0
  case default
    return
  end select

  len_in = input%shape(1)
  len_out = output%shape(1)
  n_pad = output%indices(2)

  do facet = 1, output%indices(3)
    src_first = output%adj_ja(1, 2*facet - 1)
    dst_first = output%adj_ja(2, 2*facet - 1)
    do concurrent( sample = 1:size(output%val, dim=2), &
                   channel = 1:output%shape(2) )
      do offset = 0, n_pad - 1
        src_pos = src_first + stride * offset + (channel - 1) * len_in
        dst_pos = dst_first + offset + (channel - 1) * len_out
        output%val(dst_pos, sample) = input%val(src_pos, sample)
      end do
    end do
  end do

end subroutine fill_edge_region_1d
48 !-------------------------------------------------------------------------------
pure subroutine accumulate_edge_gradients_1d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Backward pass of the 1D edge fill: add the gradient of each padded
  !! output cell onto the input cell it was copied from.
  !!
  !! `input_shape` is read as (length, channels, batch). `indices(1)` is the
  !! padding mode (3 circular, 4 reflection, 5 replication; other values are
  !! a no-op), `indices(2)` the pad width per side and `indices(3)` the
  !! number of edge facets. Facet `f` takes its first source position from
  !! `adj_ja(1, 2*f-1)` and its destination range from
  !! `adj_ja(2, 2*f-1) : adj_ja(2, 2*f)`.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(3), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: facet, sample, channel, pos
  integer :: src, dst, stride
  integer :: len_in, len_out
  integer :: src_first, dst_first, dst_last
  real(real32) :: total

  len_in = input_shape(1)
  len_out = len_in + 2 * indices(2)

  if(indices(3) .eq. 0) return

  select case(indices(1))
  case(3, 4)
    ! Circular copies walk forwards through the input, reflections backwards.
    if(indices(1) .eq. 3)then
      stride = 1
    else
      stride = -1
    end if
    do facet = 1, indices(3)
      src_first = adj_ja(1, 2*facet - 1)
      dst_first = adj_ja(2, 2*facet - 1)
      dst_last = adj_ja(2, 2*facet)
      do sample = 1, input_shape(3)
        do channel = 1, input_shape(2)
          do pos = dst_first, dst_last
            dst = pos + (channel - 1) * len_out
            src = src_first + stride * (pos - dst_first) + &
                (channel - 1) * len_in
            output(src, sample) = output(src, sample) + &
                upstream_grad(dst, sample)
          end do
        end do
      end do
    end do
  case(5)
    ! Every padded cell copies one edge value, so that value receives the
    ! summed gradient of the whole padded run.
    do facet = 1, indices(3)
      src_first = adj_ja(1, 2*facet - 1)
      dst_first = adj_ja(2, 2*facet - 1)
      dst_last = adj_ja(2, 2*facet)
      do sample = 1, input_shape(3)
        do channel = 1, input_shape(2)
          total = 0._real32
          do pos = dst_first, dst_last
            total = total + &
                upstream_grad(pos + (channel - 1) * len_out, sample)
          end do
          src = src_first + (channel - 1) * len_in
          output(src, sample) = output(src, sample) + total
        end do
      end do
    end do
  end select

end subroutine accumulate_edge_gradients_1d_val
111 !###############################################################################
112
113
114 !###############################################################################
subroutine fill_corner_region_2d(input, output)
  !! Fill the corner padding regions of a 2D padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (other values leave `output` unchanged);
  !! (4) number of edge facets, whose four `adj_ja` columns each precede the
  !! corner columns; (5) number of corner facets. For corner `f`, row 1 of
  !! `output%adj_ja` holds the source (row start, row end, col start,
  !! col end) and row 2 the destination ranges in the same order.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: i, j, m, s, f, base
  integer :: step, idx_in, idx_out, idx_shift
  integer :: input_h, input_w, output_h, output_w
  integer, dimension(2,2) :: orig, dest

  ! Source step per destination cell, applied in both dimensions:
  ! +1 for circular, -1 for reflection (mirrored in both dimensions) and 0
  ! for replication (every corner cell copies the single corner value).
  select case(output%indices(1))
  case(3)
    step = 1
  case(4)
    step = -1
  case(5)
    step = 0
  case default
    return
  end select

  input_h = input%shape(1)
  input_w = input%shape(2)
  output_h = output%shape(1)
  output_w = output%shape(2)

  ! Corner columns follow the edge-facet columns in adj_ja.
  idx_shift = output%indices(4) * 4
  do f = 1, output%indices(5)
    base = (f-1)*4 + idx_shift
    orig(1:2,1) = output%adj_ja(1, base + 1:base + 2)
    orig(1:2,2) = output%adj_ja(1, base + 3:base + 4)
    dest(1:2,1) = output%adj_ja(2, base + 1:base + 2)
    dest(1:2,2) = output%adj_ja(2, base + 3:base + 4)

    do concurrent( s = 1:size(output%val, dim=2), m = 1:output%shape(3) )
      do j = dest(1,2), dest(2,2)
        do i = dest(1,1), dest(2,1)
          idx_out = i + (j-1) * output_h + (m-1) * output_h * output_w
          idx_in = orig(1,1) + step * (i - dest(1,1)) + &
              (orig(1,2) + step * (j - dest(1,2)) - 1) * input_h + &
              (m-1) * input_h * input_w
          output%val(idx_out, s) = input%val(idx_in, s)
        end do
      end do
    end do
  end do

end subroutine fill_corner_region_2d
171 !-------------------------------------------------------------------------------
subroutine fill_edge_region_2d(input, output)
  !! Fill the edge (non-corner) padding regions of a 2D padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (other values leave `output` unchanged);
  !! (4) number of edge facets; (5 + f) the padded dimension of facet `f`.
  !! For facet `f`, columns `4*(f-1)+1 : 4*f` of `output%adj_ja` hold, in
  !! row 1, the source (row start, row end, col start, col end) and, in
  !! row 2, the destination ranges in the same order.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: i, j, m, s, f, idim, mode
  integer :: step1, step2, idx_in, idx_out
  integer :: input_h, input_w, output_h, output_w
  integer, dimension(2,2) :: orig, dest

  input_h = input%shape(1)
  input_w = input%shape(2)
  output_h = output%shape(1)
  output_w = output%shape(2)
  mode = output%indices(1)

  do f = 1, output%indices(4)
    idim = output%indices(5 + f)
    orig(1:2,1) = output%adj_ja(1,(f-1)*4 + 1:(f-1)*4 + 2)
    orig(1:2,2) = output%adj_ja(1,(f-1)*4 + 3:(f-1)*4 + 4)
    dest(1:2,1) = output%adj_ja(2,(f-1)*4 + 1:(f-1)*4 + 2)
    dest(1:2,2) = output%adj_ja(2,(f-1)*4 + 3:(f-1)*4 + 4)

    ! Per-dimension source step: +1 walks forwards through the input,
    ! -1 mirrors across the padded dimension (reflection) and 0 pins the
    ! padded dimension to the boundary row/column (replication). The
    ! non-padded dimension always starts at orig(1,·), matching the
    ! general mapping used by accumulate_edge_gradients_2d_val (the old
    ! replication code assumed orig(1,·) = 1 here).
    step1 = 1
    step2 = 1
    select case(mode)
    case(3)
    case(4)
      if(idim .eq. 1) step1 = -1
      if(idim .eq. 2) step2 = -1
    case(5)
      select case(idim)
      case(1)
        step1 = 0
      case(2)
        step2 = 0
      case default
        cycle
      end select
    case default
      cycle
    end select

    do concurrent( s = 1:size(output%val, dim=2), m = 1:output%shape(3) )
      do j = dest(1,2), dest(2,2)
        do i = dest(1,1), dest(2,1)
          idx_out = i + (j-1) * output_h + (m-1) * output_h * output_w
          idx_in = orig(1,1) + step1 * (i - dest(1,1)) + &
              (orig(1,2) + step2 * (j - dest(1,2)) - 1) * input_h + &
              (m-1) * input_h * input_w
          output%val(idx_out, s) = input%val(idx_in, s)
        end do
      end do
    end do
  end do

end subroutine fill_edge_region_2d
241 !-------------------------------------------------------------------------------
pure subroutine accumulate_corner_gradients_2d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Backward pass of the 2D corner fill: add the gradient of each padded
  !! corner cell onto the input cell(s) it was copied from.
  !!
  !! `input_shape` is read as (height, width, channels, batch). `indices(1)`
  !! is the padding mode (3 circular, 4 reflection, 5 replication; other
  !! values are a no-op), `indices(2:3)` the pad widths, `indices(4)` the
  !! number of edge facets (whose four `adj_ja` columns each come first) and
  !! `indices(5)` the number of corner facets. For each corner, row 1 of
  !! `adj_ja` holds the source (row start, row end, col start, col end) and
  !! row 2 the destination ranges in the same order.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(4), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: i, j, m, s, f, base
  integer :: idx_in, idx_out
  integer :: in_h, in_plane, out_h, out_plane
  integer, dimension(4) :: src, dst
  real(real32) :: grad_sum

  in_h = input_shape(1)
  in_plane = input_shape(1) * input_shape(2)
  out_h = in_h + 2 * indices(2)
  out_plane = out_h * (input_shape(2) + 2 * indices(3))

  if(indices(5) .eq. 0) return
  if(indices(1) .lt. 3 .or. indices(1) .gt. 5) return

  do f = 1, indices(5)
    base = (indices(4) + f - 1) * 4
    src = adj_ja(1, base + 1:base + 4)
    dst = adj_ja(2, base + 1:base + 4)

    select case(indices(1))
    case(3) ! circular: forward one-to-one copy
      do concurrent( s = 1:input_shape(4), m = 1:input_shape(3), &
                     j = src(3):src(4), i = src(1):src(2) )
        idx_in = i + (j-1) * in_h + (m-1) * in_plane
        idx_out = dst(1) + (i - src(1)) + &
            (dst(3) + (j - src(3)) - 1) * out_h + (m-1) * out_plane
        output(idx_in, s) = output(idx_in, s) + upstream_grad(idx_out, s)
      end do
    case(4) ! reflection: source ranges run from end to start
      do concurrent( s = 1:input_shape(4), m = 1:input_shape(3), &
                     j = src(4):src(3), i = src(2):src(1) )
        idx_in = i + (j-1) * in_h + (m-1) * in_plane
        idx_out = dst(1) - (i - src(1)) + &
            (dst(4) - (j - src(4)) - 1) * out_h + (m-1) * out_plane
        output(idx_in, s) = output(idx_in, s) + upstream_grad(idx_out, s)
      end do
    case(5) ! replication: the corner value gets the whole block's gradient
      do s = 1, input_shape(4)
        do m = 1, input_shape(3)
          grad_sum = 0._real32
          do j = dst(3), dst(4)
            do i = dst(1), dst(2)
              grad_sum = grad_sum + &
                  upstream_grad(i + (j-1) * out_h + (m-1) * out_plane, s)
            end do
          end do
          idx_in = src(1) + (src(3) - 1) * in_h + (m-1) * in_plane
          output(idx_in, s) = output(idx_in, s) + grad_sum
        end do
      end do
    end select
  end do

end subroutine accumulate_corner_gradients_2d_val
342 !-------------------------------------------------------------------------------
pure subroutine accumulate_edge_gradients_2d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Backward pass of the 2D edge fill: add the gradient of each padded
  !! edge cell onto the input cell(s) it was copied from.
  !!
  !! `input_shape` is read as (height, width, channels, batch). `indices(1)`
  !! is the padding mode (3 circular, 4 reflection, 5 replication; other
  !! values are a no-op), `indices(2:3)` the pad widths, `indices(4)` the
  !! number of edge facets and `indices(5 + f)` the padded dimension of
  !! facet `f`. Row 1 of `adj_ja` columns `4*(f-1)+1 : 4*f` holds the source
  !! (row start, row end, col start, col end) and row 2 the destination
  !! ranges in the same order.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(4), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: i, j, m, s, f
  integer :: idx_in, idx_out
  integer :: input_size_h, input_size_w
  integer :: output_size_h, output_size_w
  integer :: in_plane, out_plane
  integer :: num_edge_facets, facet_dim, mode
  integer, dimension(4) :: src, dst
  real(real32) :: grad_sum

  input_size_h = input_shape(1)
  input_size_w = input_shape(2)
  output_size_h = input_size_h + 2 * indices(2)
  output_size_w = input_size_w + 2 * indices(3)
  in_plane = input_size_h * input_size_w
  out_plane = output_size_h * output_size_w
  num_edge_facets = indices(4)
  mode = indices(1)

  if(num_edge_facets .eq. 0) return
  if(mode .lt. 3 .or. mode .gt. 5) return

  do f = 1, num_edge_facets
    src = adj_ja(1, (f-1)*4 + 1:(f-1)*4 + 4)
    dst = adj_ja(2, (f-1)*4 + 1:(f-1)*4 + 4)

    select case(mode)
    case(3)
      ! Circular: a forward one-to-one copy in both dimensions, so the
      ! same mapping serves row-padded and column-padded facets.
      do concurrent( s = 1:input_shape(4), m = 1:input_shape(3), &
                     j = src(3):src(4), i = src(1):src(2) )
        idx_in = i + (j-1) * input_size_h + (m-1) * in_plane
        idx_out = dst(1) + (i - src(1)) + &
            (dst(3) + (j - src(3)) - 1) * output_size_h + (m-1) * out_plane
        output(idx_in, s) = output(idx_in, s) + upstream_grad(idx_out, s)
      end do
    case(4)
      ! Reflection: mirrored along the padded dimension only; the source
      ! range of that dimension is stored end-first.
      facet_dim = indices(5 + f)
      if(facet_dim .eq. 1)then
        do concurrent( s = 1:input_shape(4), m = 1:input_shape(3), &
                       j = src(3):src(4), i = src(2):src(1) )
          idx_in = i + (j-1) * input_size_h + (m-1) * in_plane
          idx_out = dst(1) - (i - src(1)) + &
              (dst(3) + (j - src(3)) - 1) * output_size_h + (m-1) * out_plane
          output(idx_in, s) = output(idx_in, s) + upstream_grad(idx_out, s)
        end do
      else
        do concurrent( s = 1:input_shape(4), m = 1:input_shape(3), &
                       j = src(4):src(3), i = src(1):src(2) )
          idx_in = i + (j-1) * input_size_h + (m-1) * in_plane
          idx_out = dst(1) + (i - src(1)) + &
              (dst(4) - (j - src(4)) - 1) * output_size_h + (m-1) * out_plane
          output(idx_in, s) = output(idx_in, s) + upstream_grad(idx_out, s)
        end do
      end if
    case(5)
      ! Replication: each boundary value receives the summed gradient of
      ! the padded run copied from it.
      facet_dim = indices(5 + f)
      if(facet_dim .eq. 1)then
        do s = 1, input_shape(4)
          do m = 1, input_shape(3)
            do j = src(3), src(4)
              grad_sum = 0._real32
              do i = dst(1), dst(2)
                idx_out = i + (dst(3) + (j - src(3)) - 1) * output_size_h + &
                    (m-1) * out_plane
                grad_sum = grad_sum + upstream_grad(idx_out, s)
              end do
              idx_in = src(1) + (j-1) * input_size_h + (m-1) * in_plane
              output(idx_in, s) = output(idx_in, s) + grad_sum
            end do
          end do
        end do
      else
        do s = 1, input_shape(4)
          do m = 1, input_shape(3)
            do i = src(1), src(2)
              grad_sum = 0._real32
              do j = dst(3), dst(4)
                idx_out = dst(1) + (i - src(1)) + (j-1) * output_size_h + &
                    (m-1) * out_plane
                grad_sum = grad_sum + upstream_grad(idx_out, s)
              end do
              idx_in = i + (src(3) - 1) * input_size_h + (m-1) * in_plane
              output(idx_in, s) = output(idx_in, s) + grad_sum
            end do
          end do
        end do
      end if
    end select
  end do

end subroutine accumulate_edge_gradients_2d_val
505 !###############################################################################
506
507
508 !###############################################################################
subroutine fill_corner_region_3d(input, output)
  !! Fill the corner padding regions (padded in all three dimensions) of a
  !! 3D padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (other values leave `output` unchanged);
  !! (5) and (6) the numbers of face and edge facets, whose six `adj_ja`
  !! columns each precede the corners; (7) the number of corner facets.
  !! Row 1 of `output%adj_ja` holds the source and row 2 the destination
  !! (start, end) pairs for dimensions 1, 2 and 3.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: i, j, k, m, s, f, d, col0
  integer :: stride, src_pos, dst_pos
  integer :: in_h, in_hw, in_hwd
  integer :: out_h, out_hw, out_hwd
  integer, dimension(2,3) :: src, dst

  ! One stride for all three dimensions: forwards for circular, mirrored
  ! for reflection, and zero for replication (the single corner value).
  select case(output%indices(1))
  case(3)
    stride = 1
  case(4)
    stride = -1
  case(5)
    stride = 0
  case default
    return
  end select

  in_h = input%shape(1)
  in_hw = in_h * input%shape(2)
  in_hwd = in_hw * input%shape(3)
  out_h = output%shape(1)
  out_hw = out_h * output%shape(2)
  out_hwd = out_hw * output%shape(3)

  do f = 1, output%indices(7)
    col0 = (output%indices(5) + output%indices(6) + f - 1) * 6
    do d = 1, 3
      src(1:2, d) = output%adj_ja(1, col0 + 2*d - 1:col0 + 2*d)
      dst(1:2, d) = output%adj_ja(2, col0 + 2*d - 1:col0 + 2*d)
    end do

    do concurrent( s = 1:size(output%val, dim=2), m = 1:output%shape(4) )
      do k = dst(1,3), dst(2,3)
        do j = dst(1,2), dst(2,2)
          do i = dst(1,1), dst(2,1)
            dst_pos = i + (j-1) * out_h + (k-1) * out_hw + (m-1) * out_hwd
            src_pos = src(1,1) + stride * (i - dst(1,1)) + &
                (src(1,2) + stride * (j - dst(1,2)) - 1) * in_h + &
                (src(1,3) + stride * (k - dst(1,3)) - 1) * in_hw + &
                (m-1) * in_hwd
            output%val(dst_pos, s) = input%val(src_pos, s)
          end do
        end do
      end do
    end do
  end do

end subroutine fill_corner_region_3d
579 !-------------------------------------------------------------------------------
subroutine fill_edge_region_3d(input, output)
  !! Fill the edge padding regions (padded in two dimensions) of a 3D
  !! padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (other values leave `output` unchanged);
  !! (5) number of face facets, whose six `adj_ja` columns each precede the
  !! edges; (6) number of edge facets; (7 + indices(5) + f) the dimension
  !! edge `f` runs along (the one dimension it is not padded in). Row 1 of
  !! `output%adj_ja` holds the source and row 2 the destination
  !! (start, end) pairs for dimensions 1, 2 and 3.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: i, j, k, m, s, f, d, idim, mode, base
  integer :: idx_in, idx_out
  integer :: input_h, input_w, input_d
  integer :: output_h, output_w, output_d
  integer, dimension(3) :: step
  integer, dimension(2,3) :: orig, dest

  mode = output%indices(1)
  if(mode .lt. 3 .or. mode .gt. 5) return

  input_h = input%shape(1)
  input_w = input%shape(2)
  input_d = input%shape(3)
  output_h = output%shape(1)
  output_w = output%shape(2)
  output_d = output%shape(3)

  edge_loop: do f = 1, output%indices(6)
    idim = output%indices(7 + output%indices(5) + f)
    base = (output%indices(5) + f - 1) * 6
    do d = 1, 3
      orig(1:2,d) = output%adj_ja(1, base + 2*d - 1:base + 2*d)
      dest(1:2,d) = output%adj_ja(2, base + 2*d - 1:base + 2*d)
    end do

    ! Per-dimension source step. The edge spans the input along `idim`,
    ! so that dimension is always walked forwards from orig(1,idim). The
    ! two padded dimensions are walked forwards (circular), mirrored
    ! (reflection) or pinned to the boundary plane (replication).
    ! Previously reflection mirrored only `idim` -- the un-padded
    ! dimension -- and replication assumed orig(1,idim) = 1.
    ! NOTE(review): accumulate_edge_gradients_3d_val must use the same
    ! convention for its reflection/replication gradients -- confirm.
    select case(mode)
    case(3)
      step = 1
    case(4)
      step = -1
    case(5)
      step = 0
    end select
    if(mode .ne. 3)then
      ! Skip malformed facets whose along-dimension is not 1, 2 or 3.
      if(idim .lt. 1 .or. idim .gt. 3) cycle edge_loop
      step(idim) = 1
    end if

    do concurrent( s = 1:size(output%val, dim=2), m = 1:output%shape(4) )
      do k = dest(1,3), dest(2,3)
        do j = dest(1,2), dest(2,2)
          do i = dest(1,1), dest(2,1)
            idx_out = i + (j-1) * output_h + &
                (k-1) * output_h * output_w + &
                (m-1) * output_h * output_w * output_d
            idx_in = orig(1,1) + step(1) * (i - dest(1,1)) + &
                (orig(1,2) + step(2) * (j - dest(1,2)) - 1) * input_h + &
                (orig(1,3) + step(3) * (k - dest(1,3)) - 1) * &
                input_h * input_w + &
                (m-1) * input_h * input_w * input_d
            output%val(idx_out, s) = input%val(idx_in, s)
          end do
        end do
      end do
    end do
  end do edge_loop

end subroutine fill_edge_region_3d
691 !-------------------------------------------------------------------------------
subroutine fill_face_region_3d(input, output)
  !! Fill the face padding regions (padded in a single dimension) of a 3D
  !! padded output.
  !!
  !! Reads from `output%indices`: (1) padding mode -- 3 circular,
  !! 4 reflection, 5 replication (other values leave `output` unchanged);
  !! (5) number of face facets; (7 + f) the dimension face `f` is
  !! perpendicular to. Face `f` owns columns `6*(f-1)+1 : 6*f` of
  !! `output%adj_ja`; row 1 holds the source and row 2 the destination
  !! (start, end) pairs for dimensions 1, 2 and 3.
  implicit none

  ! Arguments
  type(array_type), intent(in) :: input
  type(array_type), intent(inout) :: output

  ! Local variables
  integer :: i, j, k, m, s, f, d, mode, normal, col0
  integer :: src_pos, dst_pos
  integer :: in_h, in_hw, in_hwd
  integer :: out_h, out_hw, out_hwd
  integer, dimension(3) :: stride
  integer, dimension(2,3) :: src, dst

  mode = output%indices(1)
  if(mode .lt. 3 .or. mode .gt. 5) return

  in_h = input%shape(1)
  in_hw = in_h * input%shape(2)
  in_hwd = in_hw * input%shape(3)
  out_h = output%shape(1)
  out_hw = out_h * output%shape(2)
  out_hwd = out_hw * output%shape(3)

  face_loop: do f = 1, output%indices(5)
    normal = output%indices(7 + f)
    col0 = (f-1) * 6
    do d = 1, 3
      src(1:2, d) = output%adj_ja(1, col0 + 2*d - 1:col0 + 2*d)
      dst(1:2, d) = output%adj_ja(2, col0 + 2*d - 1:col0 + 2*d)
    end do

    ! Every dimension is walked forwards except the face normal, which is
    ! mirrored for reflection and pinned to the boundary plane for
    ! replication. A replication facet with an unknown normal is skipped.
    stride = 1
    if(normal .ge. 1 .and. normal .le. 3)then
      if(mode .eq. 4) stride(normal) = -1
      if(mode .eq. 5) stride(normal) = 0
    else if(mode .eq. 5)then
      cycle face_loop
    end if

    do concurrent( s = 1:size(output%val, dim=2), m = 1:output%shape(4) )
      do k = dst(1,3), dst(2,3)
        do j = dst(1,2), dst(2,2)
          do i = dst(1,1), dst(2,1)
            dst_pos = i + (j-1) * out_h + (k-1) * out_hw + (m-1) * out_hwd
            src_pos = src(1,1) + stride(1) * (i - dst(1,1)) + &
                (src(1,2) + stride(2) * (j - dst(1,2)) - 1) * in_h + &
                (src(1,3) + stride(3) * (k - dst(1,3)) - 1) * in_hw + &
                (m-1) * in_hwd
            output%val(dst_pos, s) = input%val(src_pos, s)
          end do
        end do
      end do
    end do
  end do face_loop

end subroutine fill_face_region_3d
798 !-------------------------------------------------------------------------------
pure subroutine accumulate_corner_gradients_3d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Backward pass of the 3D corner fill: add the gradient of each padded
  !! corner cell onto the input cell(s) it was copied from.
  !!
  !! `input_shape` is read as (height, width, depth, channels, batch).
  !! `indices(1)` is the padding mode (3 circular, 4 reflection,
  !! 5 replication; other values are a no-op), `indices(2:4)` the pad
  !! widths, `indices(5)`/`indices(6)` the numbers of face/edge facets whose
  !! six `adj_ja` columns each precede the corners, and `indices(7)` the
  !! number of corner facets. Row 1 of `adj_ja` holds the source and row 2
  !! the destination (start, end) pairs for dimensions 1, 2 and 3.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(5), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: i, j, k, m, s, f, d, col0
  integer :: stride, src_pos, dst_pos
  integer :: in_h, in_hw, in_hwd
  integer :: out_h, out_hw, out_hwd
  integer, dimension(2,3) :: src, dst
  real(real32) :: acc

  in_h = input_shape(1)
  in_hw = in_h * input_shape(2)
  in_hwd = in_hw * input_shape(3)
  out_h = input_shape(1) + 2 * indices(2)
  out_hw = out_h * (input_shape(2) + 2 * indices(3))
  out_hwd = out_hw * (input_shape(3) + 2 * indices(4))

  if(indices(7) .eq. 0) return

  ! Circular walks the source forwards, reflection backwards; replication
  ! (stride 0) reduces the whole corner block onto one input value.
  select case(indices(1))
  case(3)
    stride = 1
  case(4)
    stride = -1
  case(5)
    stride = 0
  case default
    return
  end select

  do f = 1, indices(7)
    col0 = (indices(5) + indices(6) + f - 1) * 6
    do d = 1, 3
      src(1:2, d) = adj_ja(1, col0 + 2*d - 1:col0 + 2*d)
      dst(1:2, d) = adj_ja(2, col0 + 2*d - 1:col0 + 2*d)
    end do

    if(stride .eq. 0)then
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          acc = 0._real32
          do k = dst(1,3), dst(2,3)
            do j = dst(1,2), dst(2,2)
              do i = dst(1,1), dst(2,1)
                acc = acc + upstream_grad( &
                    i + (j-1) * out_h + (k-1) * out_hw + (m-1) * out_hwd, s)
              end do
            end do
          end do
          src_pos = src(1,1) + (src(1,2) - 1) * in_h + &
              (src(1,3) - 1) * in_hw + (m-1) * in_hwd
          output(src_pos, s) = output(src_pos, s) + acc
        end do
      end do
    else
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do k = dst(1,3), dst(2,3)
            do j = dst(1,2), dst(2,2)
              do i = dst(1,1), dst(2,1)
                dst_pos = i + (j-1) * out_h + (k-1) * out_hw + (m-1) * out_hwd
                src_pos = src(1,1) + stride * (i - dst(1,1)) + &
                    (src(1,2) + stride * (j - dst(1,2)) - 1) * in_h + &
                    (src(1,3) + stride * (k - dst(1,3)) - 1) * in_hw + &
                    (m-1) * in_hwd
                output(src_pos, s) = output(src_pos, s) + &
                    upstream_grad(dst_pos, s)
              end do
            end do
          end do
        end do
      end do
    end if
  end do

end subroutine accumulate_corner_gradients_3d_val
895 !-------------------------------------------------------------------------------
pure subroutine accumulate_edge_gradients_3d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Accumulate edge gradients for 3D padding - raw array version.
  !!
  !! Edge facets form the second facet group written by pad3d: their count is
  !! indices(6), facet f's dimension tag is indices(7 + indices(5) + f), and
  !! their bounds follow the 6*indices(5) face-facet columns of adj_ja
  !! (row 1 = origin, row 2 = destination; lo/hi along dims 1, 2, 3).
  !!
  !! * indices(1) = 3 or 4 (circular / reflection): every padded element is
  !!   mapped to one input element and its upstream gradient is added there.
  !! * indices(1) = 5 (replication): the upstream gradients of all padded
  !!   elements copied from one input element are summed, then added to it.
  !! Nothing is done when there are no edge facets or for any other method.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(5), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: f, ax, idim, col
  integer :: i, j, k, m, s
  integer :: src, dst
  integer :: in_h, in_plane, in_vol
  integer :: out_h, out_plane, out_vol
  integer, dimension(3) :: stp
  integer, dimension(2,3) :: orig, dest
  real(real32) :: acc

  ! Column-major strides of the input and of the padded output
  in_h = input_shape(1)
  in_plane = in_h * input_shape(2)
  in_vol = in_plane * input_shape(3)
  out_h = input_shape(1) + 2 * indices(2)
  out_plane = out_h * (input_shape(2) + 2 * indices(3))
  out_vol = out_plane * (input_shape(3) + 2 * indices(4))

  if(indices(6) .eq. 0) return
  if(indices(1) .lt. 3 .or. indices(1) .gt. 5) return

  facet_loop: do f = 1, indices(6)
    idim = indices(7 + indices(5) + f)
    ! Edge facets are stored after the indices(5) face facets
    col = (indices(5) + f - 1) * 6
    do ax = 1, 3
      orig(1:2,ax) = adj_ja(1, col + 2*ax - 1 : col + 2*ax)
      dest(1:2,ax) = adj_ja(2, col + 2*ax - 1 : col + 2*ax)
    end do

    if(indices(1) .ne. 5)then
      ! Circular (3) or reflection (4): one-to-one element mapping.
      ! NOTE(review): the replication branch below treats idim as the
      ! dimension the edge runs along (copied 1:1, the other two collapse),
      ! yet reflection negates the step only along idim. Confirm this
      ! matches fill_edge_region_3d.
      do ax = 1, 3
        stp(ax) = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. ax)
      end do
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do k = dest(1,3), dest(2,3)
            do j = dest(1,2), dest(2,2)
              do i = dest(1,1), dest(2,1)
                dst = i + (j - 1) * out_h + (k - 1) * out_plane + &
                    (m - 1) * out_vol
                src = orig(1,1) + stp(1) * (i - dest(1,1)) + &
                    (orig(1,2) + stp(2) * (j - dest(1,2)) - 1) * in_h + &
                    (orig(1,3) + stp(3) * (k - dest(1,3)) - 1) * in_plane + &
                    (m - 1) * in_vol
                output(src, s) = output(src, s) + upstream_grad(dst, s)
              end do
            end do
          end do
        end do
      end do
      cycle facet_loop
    end if

    ! Replication (5): sum every copy of an input element, then add once
    select case(idim)
    case(1) ! edge runs along dimension 1; dimensions 2 and 3 collapse
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do i = dest(1,1), dest(2,1)
            src = i - dest(1,1) + 1 + (orig(1,2) - 1) * in_h + &
                (orig(1,3) - 1) * in_plane + (m - 1) * in_vol
            acc = 0._real32
            do k = dest(1,3), dest(2,3)
              do j = dest(1,2), dest(2,2)
                dst = i + (j - 1) * out_h + (k - 1) * out_plane + &
                    (m - 1) * out_vol
                acc = acc + upstream_grad(dst, s)
              end do
            end do
            output(src, s) = output(src, s) + acc
          end do
        end do
      end do
    case(2) ! edge runs along dimension 2; dimensions 1 and 3 collapse
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do j = dest(1,2), dest(2,2)
            src = orig(1,1) + (j - dest(1,2)) * in_h + &
                (orig(1,3) - 1) * in_plane + (m - 1) * in_vol
            acc = 0._real32
            do k = dest(1,3), dest(2,3)
              do i = dest(1,1), dest(2,1)
                dst = i + (j - 1) * out_h + (k - 1) * out_plane + &
                    (m - 1) * out_vol
                acc = acc + upstream_grad(dst, s)
              end do
            end do
            output(src, s) = output(src, s) + acc
          end do
        end do
      end do
    case(3) ! edge runs along dimension 3; dimensions 1 and 2 collapse
      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do k = dest(1,3), dest(2,3)
            src = orig(1,1) + (orig(1,2) - 1) * in_h + &
                (k - dest(1,3)) * in_plane + (m - 1) * in_vol
            acc = 0._real32
            do j = dest(1,2), dest(2,2)
              do i = dest(1,1), dest(2,1)
                dst = i + (j - 1) * out_h + (k - 1) * out_plane + &
                    (m - 1) * out_vol
                acc = acc + upstream_grad(dst, s)
              end do
            end do
            output(src, s) = output(src, s) + acc
          end do
        end do
      end do
    end select
  end do facet_loop

end subroutine accumulate_edge_gradients_3d_val
1043 !-------------------------------------------------------------------------------
pure subroutine accumulate_face_gradients_3d_val(upstream_grad, output, &
    input_shape, indices, adj_ja)
  !! Accumulate face gradients for 3D padding - raw array version.
  !!
  !! Face facets form the first facet group written by pad3d: their count is
  !! indices(5), facet f's padded dimension is indices(7 + f), and their
  !! bounds occupy the first 6*indices(5) columns of adj_ja (row 1 = origin,
  !! row 2 = destination; lo/hi along dims 1, 2, 3).
  !!
  !! * indices(1) = 3 or 4 (circular / reflection): every padded element is
  !!   mapped to one input element (walking backwards along the padded
  !!   dimension for reflection) and its upstream gradient is added there.
  !! * indices(1) = 5 (replication): the gradients of all padded elements
  !!   copied from the same input element are summed and added to it. The
  !!   source plane along the padded dimension idim is orig(1,idim).
  !! Nothing is done when there are no face facets or for any other method.
  implicit none

  ! Arguments
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(inout) :: output
  integer, dimension(5), intent(in) :: input_shape
  integer, dimension(:), intent(in) :: indices
  integer, dimension(:,:), intent(in) :: adj_ja

  ! Local variables
  integer :: i, j, k, m, s, f, ax, idim, col
  integer :: step1, step2, step3, idx_in, idx_out
  integer :: input_h, input_w, input_d
  integer :: output_h, output_w, output_d
  integer, dimension(2,3) :: orig, dest
  real(real32) :: grad_sum

  input_h = input_shape(1)
  input_w = input_shape(2)
  input_d = input_shape(3)
  ! Each extent is padded by its own dimension's pad size, indices(2:4)
  output_h = input_h + 2 * indices(2)
  output_w = input_w + 2 * indices(3)
  output_d = input_d + 2 * indices(4)

  if(indices(5) .eq. 0) return
  if(indices(1) .lt. 3 .or. indices(1) .gt. 5) return

  do f = 1, indices(5)
    idim = indices(7 + f)
    col = (f - 1) * 6
    do ax = 1, 3
      orig(1:2,ax) = adj_ja(1, col + 2*ax - 1 : col + 2*ax)
      dest(1:2,ax) = adj_ja(2, col + 2*ax - 1 : col + 2*ax)
    end do

    select case(indices(1))
    case(3, 4) ! circular or reflection
      step1 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 1)
      step2 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 2)
      step3 = merge(-1, 1, indices(1) .eq. 4 .and. idim .eq. 3)

      do s = 1, input_shape(5)
        do m = 1, input_shape(4)
          do k = dest(1,3), dest(2,3)
            do j = dest(1,2), dest(2,2)
              do i = dest(1,1), dest(2,1)
                idx_out = i + (j - 1) * output_h + &
                    (k - 1) * output_h * output_w + &
                    (m - 1) * output_h * output_w * output_d
                idx_in = orig(1,1) + step1 * (i - dest(1,1)) + &
                    (orig(1,2) + step2 * (j - dest(1,2)) - 1) * input_h + &
                    (orig(1,3) + step3 * (k - dest(1,3)) - 1) * &
                    input_h * input_w + &
                    (m - 1) * input_h * input_w * input_d
                output(idx_in, s) = output(idx_in, s) + &
                    upstream_grad(idx_out, s)
              end do
            end do
          end do
        end do
      end do

    case(5) ! replication
      select case(idim)
      case(1) ! Face perpendicular to dimension 1
        do s = 1, input_shape(5)
          do m = 1, input_shape(4)
            do k = dest(1,3), dest(2,3)
              do j = dest(1,2), dest(2,2)
                idx_in = orig(1,1) + &
                    (j - dest(1,2)) * input_h + &
                    (k - dest(1,3)) * input_h * input_w + &
                    (m - 1) * input_h * input_w * input_d
                grad_sum = 0._real32
                do i = dest(1,1), dest(2,1)
                  idx_out = i + (j - 1) * output_h + &
                      (k - 1) * output_h * output_w + &
                      (m - 1) * output_h * output_w * output_d
                  grad_sum = grad_sum + upstream_grad(idx_out, s)
                end do
                output(idx_in, s) = output(idx_in, s) + grad_sum
              end do
            end do
          end do
        end do
      case(2) ! Face perpendicular to dimension 2
        do s = 1, input_shape(5)
          do m = 1, input_shape(4)
            do k = dest(1,3), dest(2,3)
              do i = dest(1,1), dest(2,1)
                ! The source plane along dim 2 is orig(1,2), exactly as
                ! case(1) uses orig(1,1); without it every face would feed
                ! the first plane.
                idx_in = i - dest(1,1) + 1 + &
                    (orig(1,2) - 1) * input_h + &
                    (k - dest(1,3)) * input_h * input_w + &
                    (m - 1) * input_h * input_w * input_d
                grad_sum = 0._real32
                do j = dest(1,2), dest(2,2)
                  idx_out = i + (j - 1) * output_h + &
                      (k - 1) * output_h * output_w + &
                      (m - 1) * output_h * output_w * output_d
                  grad_sum = grad_sum + upstream_grad(idx_out, s)
                end do
                output(idx_in, s) = output(idx_in, s) + grad_sum
              end do
            end do
          end do
        end do
      case(3) ! Face perpendicular to dimension 3
        do s = 1, input_shape(5)
          do m = 1, input_shape(4)
            do j = dest(1,2), dest(2,2)
              do i = dest(1,1), dest(2,1)
                ! The source plane along dim 3 is orig(1,3)
                idx_in = i - dest(1,1) + 1 + &
                    (j - dest(1,2)) * input_h + &
                    (orig(1,3) - 1) * input_h * input_w + &
                    (m - 1) * input_h * input_w * input_d
                grad_sum = 0._real32
                do k = dest(1,3), dest(2,3)
                  idx_out = i + (j - 1) * output_h + &
                      (k - 1) * output_h * output_w + &
                      (m - 1) * output_h * output_w * output_d
                  grad_sum = grad_sum + upstream_grad(idx_out, s)
                end do
                output(idx_in, s) = output(idx_in, s) + grad_sum
              end do
            end do
          end do
        end do
      end select
    end select
  end do

end subroutine accumulate_face_gradients_3d_val
1187 !###############################################################################
1188
1189
1190 !###############################################################################
module function pad1d(input, facets, pad_size, imethod) result(output)
  !! 1D padding operation.
  !!
  !! Builds a (length + 2*pad_size, features, samples) result and copies the
  !! input into its interior. For imethod 3-5 (circular / reflection /
  !! replication, as labelled in fill_edge_region_1d) the two padded ends are
  !! then filled by fill_edge_region_1d; otherwise they stay zero.
  !!
  !! Metadata stored on the result for the backward pass:
  !!   indices     = [ imethod, pad_size, facets%num ]
  !!   adj_ja(1,:) = origin bounds, adj_ja(2,:) = destination bounds,
  !!                 two entries (lower, upper) per facet
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  type(facets_type), intent(in) :: facets
  integer, intent(in) :: pad_size
  integer, intent(in) :: imethod
  type(array_type), pointer :: output

  ! Local variables
  integer :: i, m, s
  integer :: idx_in, idx_out
  integer :: input_size, output_size
  integer, dimension(3) :: output_shape

  input_size = input%shape(1)
  output_size = input_size + 2 * pad_size

  output_shape = [ output_size, input%shape(2), size(input%val, dim=2) ]
  output => input%create_result(array_shape = output_shape)

  ! Save the padding metadata. Exactly three header entries are stored, so
  ! size the array for them alone: sizing by 2 + facets%num wrote past the
  ! end when there were no facets and left trailing entries undefined.
  allocate(output%indices(3))
  output%indices(1) = imethod
  output%indices(2) = pad_size
  output%indices(3) = facets%num
  allocate(output%adj_ja(2, 2 * facets%num))
  do i = 1, facets%num
    output%adj_ja(1,(i-1)*2 + 1) = facets%orig_bound(1,1,i)
    output%adj_ja(2,(i-1)*2 + 1) = facets%dest_bound(1,1,i)
    output%adj_ja(1,(i-1)*2 + 2) = facets%orig_bound(2,1,i)
    output%adj_ja(2,(i-1)*2 + 2) = facets%dest_bound(2,1,i)
  end do

  ! Initialise with zero; only modes 3-5 overwrite the padded ends below
  output%val = 0._real32

  ! Copy input into the interior of output
  do concurrent( &
      s = 1:output_shape(3), &
      m = 1:output_shape(2), &
      i = 1:input_size)
    idx_in = i + (m-1) * input_size
    idx_out = i + pad_size + (m-1) * output_size
    output%val(idx_out, s) = input%val(idx_in, s)
  end do

  if(output%indices(1) .ge. 3 .and. output%indices(1) .le. 5)then
    call fill_edge_region_1d( input, output )
  end if

  output%get_partial_left => get_partial_pad1d
  output%get_partial_left_val => get_partial_pad1d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'pad'
    output%left_operand => input
  end if

end function pad1d
1255 !-------------------------------------------------------------------------------
function get_partial_pad1d(this, upstream_grad) result(output)
  !! Backward pass of pad1d: gradient with respect to the unpadded input.
  !!
  !! The result has the left operand's shape plus this node's sample count,
  !! inherits this node's padding metadata (indices, adj_ja), and is filled
  !! by the raw-array routine bound to get_partial_left_val.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  ! Local variables
  integer :: n_samples
  integer, dimension(3) :: grad_shape

  n_samples = size(this%val, dim=2)
  grad_shape = [ this%left_operand%shape, n_samples ]
  call output%allocate(array_shape = grad_shape)

  ! Carry the padding metadata over to the gradient array
  output%indices = this%indices
  output%adj_ja = this%adj_ja

  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_pad1d
1276 !-------------------------------------------------------------------------------
pure subroutine get_partial_pad1d_val(this, upstream_grad, output)
  !! Raw-array backward pass of pad1d.
  !!
  !! Zeroes output, copies the interior (unpadded) part of upstream_grad into
  !! it, then for indices(1) in 3..5 lets accumulate_edge_gradients_1d_val
  !! add the contributions of the padded ends.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  ! Local variables
  integer :: m, pad
  integer :: len_in, len_out, n_feat, n_samp
  integer :: dst0, src0
  integer, dimension(3) :: input_shape

  input_shape = [ this%left_operand%shape, size(upstream_grad, dim=2) ]
  len_in = input_shape(1)
  n_feat = input_shape(2)
  n_samp = input_shape(3)
  pad = this%indices(2)
  len_out = len_in + 2 * pad

  output = 0._real32

  ! Each feature is one contiguous run; skip the leading padding in the
  ! upstream gradient
  do m = 1, n_feat
    dst0 = (m - 1) * len_in
    src0 = (m - 1) * len_out + pad
    output(dst0 + 1 : dst0 + len_in, 1:n_samp) = &
        upstream_grad(src0 + 1 : src0 + len_in, 1:n_samp)
  end do

  if(this%indices(1) .ge. 3 .and. this%indices(1) .le. 5)then
    call accumulate_edge_gradients_1d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
  end if

end subroutine get_partial_pad1d_val
1319 !###############################################################################
1320
1321
1322 !###############################################################################
module function pad2d(input, facets, pad_size, imethod) result(output)
  !! 2D padding operation.
  !!
  !! Builds a (h + 2*pad_size(1), w + 2*pad_size(2), channels, samples)
  !! result and copies the input into its interior. For imethod 3-5 the
  !! padded border is then filled by fill_corner_region_2d and
  !! fill_edge_region_2d; otherwise it stays zero.
  !!
  !! Metadata stored on the result for the backward pass:
  !!   indices(1)   = imethod
  !!   indices(2:3) = pad_size(1:2)
  !!   indices(4:5) = facets(1:2)%num
  !!   indices(6:)  = facets(1)%dim followed by facets(2)%dim
  !!   adj_ja(1,:) / adj_ja(2,:) = origin / destination bounds, four entries
  !!   per facet (lo/hi along dim 1, then lo/hi along dim 2), facets(1)
  !!   (edges) first, then facets(2) (corners)
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  type(facets_type), dimension(2), intent(in) :: facets
  integer, dimension(2), intent(in) :: pad_size
  integer, intent(in) :: imethod
  type(array_type), pointer :: output

  ! Local variables
  integer :: grp, fct, ax, base, pos
  integer :: j, m, nsamp
  integer :: in_h, in_w, n_chan, out_h, out_w
  integer :: src0, dst0
  integer, dimension(4) :: out_shape

  in_h = input%shape(1)
  in_w = input%shape(2)
  n_chan = input%shape(3)
  out_h = in_h + 2 * pad_size(1)
  out_w = in_w + 2 * pad_size(2)
  nsamp = size(input%val, dim=2)

  out_shape = [ out_h, out_w, n_chan, nsamp ]
  output => input%create_result(array_shape = out_shape)

  ! Header followed by the dimension tag of every facet, group by group
  allocate(output%indices(3 + 2 + sum( facets(:)%num )))
  output%indices(1) = imethod
  output%indices(2:3) = pad_size
  output%indices(4:5) = facets(:)%num
  pos = 5
  do grp = 1, 2
    output%indices(pos + 1 : pos + facets(grp)%num) = &
        [(facets(grp)%dim(fct), fct = 1, facets(grp)%num)]
    pos = pos + facets(grp)%num
  end do

  ! Four bounds per facet: (lo, hi) along dim 1 then along dim 2
  allocate(output%adj_ja(2, 4 * sum( facets(:)%num )))
  base = 0
  do grp = 1, 2
    do fct = 1, facets(grp)%num
      do ax = 1, 2
        output%adj_ja(1, base + 2*ax - 1 : base + 2*ax) = &
            facets(grp)%orig_bound(1:2, ax, fct)
        output%adj_ja(2, base + 2*ax - 1 : base + 2*ax) = &
            facets(grp)%dest_bound(1:2, ax, fct)
      end do
      base = base + 4
    end do
  end do

  ! Zero everything, then drop each input column into the interior
  output%val = 0._real32
  do m = 1, n_chan
    do j = 1, in_w
      src0 = (j - 1) * in_h + (m - 1) * in_h * in_w
      dst0 = pad_size(1) + (j + pad_size(2) - 1) * out_h + &
          (m - 1) * out_h * out_w
      output%val(dst0 + 1 : dst0 + in_h, 1:nsamp) = &
          input%val(src0 + 1 : src0 + in_h, 1:nsamp)
    end do
  end do

  if(output%indices(1) .ge. 3 .and. output%indices(1) .le. 5)then
    call fill_corner_region_2d( input, output )
    call fill_edge_region_2d( input, output )
  end if

  output%get_partial_left => get_partial_pad2d
  output%get_partial_left_val => get_partial_pad2d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'pad'
    output%left_operand => input
  end if

end function pad2d
1417 !-------------------------------------------------------------------------------
function get_partial_pad2d(this, upstream_grad) result(output)
  !! Backward pass of pad2d: gradient with respect to the unpadded input.
  !!
  !! The result has the left operand's shape plus this node's sample count,
  !! inherits this node's padding metadata (indices, adj_ja), and is filled
  !! by the raw-array routine bound to get_partial_left_val.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  ! Local variables
  integer :: n_samples
  integer, dimension(4) :: grad_shape

  n_samples = size(this%val, dim=2)
  grad_shape = [ this%left_operand%shape, n_samples ]
  call output%allocate(array_shape = grad_shape)

  ! Carry the padding metadata over to the gradient array
  output%indices = this%indices
  output%adj_ja = this%adj_ja

  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_pad2d
1438 !-------------------------------------------------------------------------------
pure subroutine get_partial_pad2d_val(this, upstream_grad, output)
  !! Raw-array backward pass of pad2d.
  !!
  !! Zeroes output and copies the interior (unpadded) window of upstream_grad
  !! into it. For indices(1) in 3..5 it then adds the contributions of the
  !! padded corners and edges via accumulate_corner_gradients_2d_val and
  !! accumulate_edge_gradients_2d_val.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  ! Local variables
  integer :: j, m
  integer :: h_in, w_in, h_out, w_out, n_chan, n_samp
  integer :: pad_h, pad_w
  integer :: dst0, src0
  integer, dimension(4) :: input_shape

  input_shape = [ this%left_operand%shape, size(upstream_grad, dim=2) ]
  h_in = input_shape(1)
  w_in = input_shape(2)
  n_chan = input_shape(3)
  n_samp = input_shape(4)
  pad_h = this%indices(2)
  pad_w = this%indices(3)
  h_out = h_in + 2 * pad_h
  w_out = w_in + 2 * pad_w

  output = 0._real32

  ! Each input column (fixed j, m) is a contiguous run of h_in values
  do m = 1, n_chan
    do j = 1, w_in
      dst0 = (j - 1) * h_in + (m - 1) * h_in * w_in
      src0 = pad_h + (j + pad_w - 1) * h_out + (m - 1) * h_out * w_out
      output(dst0 + 1 : dst0 + h_in, 1:n_samp) = &
          upstream_grad(src0 + 1 : src0 + h_in, 1:n_samp)
    end do
  end do

  if(this%indices(1) .ge. 3 .and. this%indices(1) .le. 5)then
    call accumulate_corner_gradients_2d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
    call accumulate_edge_gradients_2d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
  end if

end subroutine get_partial_pad2d_val
1490 !###############################################################################
1491
1492
1493 !###############################################################################
module function pad3d(input, facets, pad_size, imethod) result(output)
  !! 3D padding operation.
  !!
  !! Builds an (h + 2*pad_size(1), w + 2*pad_size(2), d + 2*pad_size(3),
  !! channels, samples) result and copies the input into its interior. For
  !! imethod 3-5 the padded shell is then filled by fill_corner_region_3d,
  !! fill_edge_region_3d and fill_face_region_3d; otherwise it stays zero.
  !!
  !! Metadata stored on the result for the backward pass:
  !!   indices(1)   = imethod
  !!   indices(2:4) = pad_size(1:3)
  !!   indices(5:7) = facets(1:3)%num
  !!   indices(8:)  = facets(1)%dim, then facets(2)%dim, then facets(3)%dim
  !!   adj_ja(1,:) / adj_ja(2,:) = origin / destination bounds, six entries
  !!   per facet (lo/hi along dims 1, 2, 3), facets(1) first, then 2, then 3
  !! accumulate_face_gradients_3d_val consumes facets(1) (faces, padded along
  !! one dimension) and accumulate_edge_gradients_3d_val consumes facets(2)
  !! (edges, padded along two dimensions); facets(3) holds the corners.
  implicit none

  ! Arguments
  type(array_type), intent(in), target :: input
  type(facets_type), dimension(3), intent(in) :: facets
  integer, dimension(3), intent(in) :: pad_size
  integer, intent(in) :: imethod
  type(array_type), pointer :: output

  ! Local variables
  integer :: grp, fct, ax, base, pos
  integer :: j, k, m, nsamp
  integer :: in_h, in_w, in_d, n_chan
  integer :: out_h, out_w, out_d
  integer :: src0, dst0
  integer, dimension(5) :: out_shape

  in_h = input%shape(1)
  in_w = input%shape(2)
  in_d = input%shape(3)
  n_chan = input%shape(4)
  out_h = in_h + 2 * pad_size(1)
  out_w = in_w + 2 * pad_size(2)
  out_d = in_d + 2 * pad_size(3)
  nsamp = size(input%val, dim=2)

  out_shape = [ out_h, out_w, out_d, n_chan, nsamp ]
  output => input%create_result(array_shape = out_shape)

  ! Header followed by the dimension tag of every facet, group by group
  allocate(output%indices(4 + 3 + sum( facets(:)%num )))
  output%indices(1) = imethod
  output%indices(2:4) = pad_size
  output%indices(5:7) = facets(:)%num
  pos = 7
  do grp = 1, 3
    output%indices(pos + 1 : pos + facets(grp)%num) = &
        [(facets(grp)%dim(fct), fct = 1, facets(grp)%num)]
    pos = pos + facets(grp)%num
  end do

  ! Six bounds per facet: (lo, hi) along dims 1, 2 and 3; groups in order
  ! faces (facets(1)), edges (facets(2)), corners (facets(3))
  allocate(output%adj_ja(2, 6 * sum( facets(:)%num )))
  base = 0
  do grp = 1, 3
    do fct = 1, facets(grp)%num
      do ax = 1, 3
        output%adj_ja(1, base + 2*ax - 1 : base + 2*ax) = &
            facets(grp)%orig_bound(1:2, ax, fct)
        output%adj_ja(2, base + 2*ax - 1 : base + 2*ax) = &
            facets(grp)%dest_bound(1:2, ax, fct)
      end do
      base = base + 6
    end do
  end do

  ! Zero everything, then drop each input column into the interior
  output%val = 0._real32
  do m = 1, n_chan
    do k = 1, in_d
      do j = 1, in_w
        src0 = (j - 1) * in_h + (k - 1) * in_h * in_w + &
            (m - 1) * in_h * in_w * in_d
        dst0 = pad_size(1) + (j + pad_size(2) - 1) * out_h + &
            (k + pad_size(3) - 1) * out_h * out_w + &
            (m - 1) * out_h * out_w * out_d
        output%val(dst0 + 1 : dst0 + in_h, 1:nsamp) = &
            input%val(src0 + 1 : src0 + in_h, 1:nsamp)
      end do
    end do
  end do

  if(output%indices(1) .ge. 3 .and. output%indices(1) .le. 5)then
    call fill_corner_region_3d( input, output )
    call fill_edge_region_3d( input, output )
    call fill_face_region_3d( input, output )
  end if

  output%get_partial_left => get_partial_pad3d
  output%get_partial_left_val => get_partial_pad3d_val
  if(input%requires_grad)then
    output%requires_grad = .true.
    output%is_forward = input%is_forward
    output%operation = 'pad'
    output%left_operand => input
  end if

end function pad3d
1617 !-------------------------------------------------------------------------------
function get_partial_pad3d(this, upstream_grad) result(output)
  !! Backward pass of pad3d: gradient with respect to the unpadded input.
  !!
  !! The result has the left operand's shape plus this node's sample count,
  !! inherits this node's padding metadata (indices, adj_ja), and is filled
  !! by the raw-array routine bound to get_partial_left_val.
  implicit none

  ! Arguments
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  ! Local variables
  integer :: n_samples
  integer, dimension(5) :: grad_shape

  n_samples = size(this%val, dim=2)
  grad_shape = [ this%left_operand%shape, n_samples ]
  call output%allocate(array_shape = grad_shape)

  ! Carry the padding metadata over to the gradient array
  output%indices = this%indices
  output%adj_ja = this%adj_ja

  call this%get_partial_left_val(upstream_grad%val, output%val)

end function get_partial_pad3d
1638 !-------------------------------------------------------------------------------
pure subroutine get_partial_pad3d_val(this, upstream_grad, output)
  !! Raw-array backward pass of pad3d.
  !!
  !! Zeroes output and copies the interior (unpadded) block of upstream_grad
  !! into it. For indices(1) in 3..5 it then adds the contributions of the
  !! padded corners, edges and faces via the accumulate_*_gradients_3d_val
  !! routines.
  implicit none

  ! Arguments
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  ! Local variables
  integer :: j, k, m
  integer :: h_in, w_in, d_in, h_out, w_out, d_out
  integer :: n_chan, n_samp
  integer :: pad_h, pad_w, pad_d
  integer :: dst0, src0
  integer, dimension(5) :: input_shape

  input_shape = [ this%left_operand%shape, size(upstream_grad, dim=2) ]
  h_in = input_shape(1)
  w_in = input_shape(2)
  d_in = input_shape(3)
  n_chan = input_shape(4)
  n_samp = input_shape(5)
  pad_h = this%indices(2)
  pad_w = this%indices(3)
  pad_d = this%indices(4)
  h_out = h_in + 2 * pad_h
  w_out = w_in + 2 * pad_w
  d_out = d_in + 2 * pad_d

  output = 0._real32

  ! Each input column (fixed j, k, m) is a contiguous run of h_in values
  do m = 1, n_chan
    do k = 1, d_in
      do j = 1, w_in
        dst0 = (j - 1) * h_in + (k - 1) * h_in * w_in + &
            (m - 1) * h_in * w_in * d_in
        src0 = pad_h + (j + pad_w - 1) * h_out + &
            (k + pad_d - 1) * h_out * w_out + &
            (m - 1) * h_out * w_out * d_out
        output(dst0 + 1 : dst0 + h_in, 1:n_samp) = &
            upstream_grad(src0 + 1 : src0 + h_in, 1:n_samp)
      end do
    end do
  end do

  if(this%indices(1) .ge. 3 .and. this%indices(1) .le. 5)then
    call accumulate_corner_gradients_3d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
    call accumulate_edge_gradients_3d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
    call accumulate_face_gradients_3d_val( &
        upstream_grad, output, input_shape, this%indices, this%adj_ja &
    )
  end if

end subroutine get_partial_pad3d_val
1699 !###############################################################################
1700
1701 end submodule athena__diffstruc_extd_submodule_pad
1702