GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_merge.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_merge
  !! Submodule containing implementations for extended diffstruc array operations

contains

!###############################################################################
  module function merge_scalar_over_channels(tsource, fsource, mask) result(output)
    !! Element-wise merge of an array with a scalar, using a spatial mask that
    !! is applied identically to every channel (last entry of tsource%shape).
    !!
    !! With the leading (spatial) dimensions of tsource%shape flattened into a
    !! single index i, and m the channel index, for every sample s:
    !!
    !!   output(i, m, s) = tsource(i, m, s)   if mask(i, 1)
    !!                   = fsource            otherwise
    !!
    !! When tsource requires gradients, the result records tsource as its left
    !! operand so the backward pass can route gradients through the mask.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: tsource
    !! Array whose values are kept where the mask is true
    real(real32), intent(in) :: fsource
    !! Scalar written where the mask is false
    logical, dimension(:,:), intent(in) :: mask
    !! Spatial mask, indexed by flattened spatial element; only column 1 is read
    type(array_type), pointer :: output
    !! Merged result, created via tsource%create_result()

    ! Local variables
    integer :: i, m, s
    integer :: num_elements, num_dims, num_channels, num_samples


    output => tsource%create_result()
    num_dims = size(tsource%shape)
    num_elements = product(tsource%shape(1:num_dims - 1))
    num_channels = tsource%shape(num_dims)
    num_samples = size(tsource%val, 2)

    ! NOTE(review): only mask(:,1) is consulted, so one spatial mask is shared
    ! by every sample; confirm this is intended if callers ever pass a mask
    ! with more than one column.
    do concurrent(s = 1:num_samples, m = 1:num_channels)
       do concurrent(i = 1:num_elements)
          ! (m-1)*num_elements offsets into channel m; the source must be read
          ! at the same flattened index as the destination, otherwise every
          ! channel would receive channel 1's values.
          if(mask(i,1)) then
             output%val(i + (m-1) * num_elements, s) = &
                  tsource%val(i + (m-1) * num_elements, s)
          else
             output%val(i + (m-1) * num_elements, s) = fsource
          end if
       end do
    end do
    ! Keep the mask so the backward pass can re-apply it.
    output%mask = mask

    output%get_partial_left => get_partial_merge_scalar_over_channels
    output%get_partial_left_val => get_partial_merge_scalar_over_channels_val
    if(tsource%requires_grad) then
       output%requires_grad = .true.
       output%is_forward = tsource%is_forward
       output%operation = 'merge_over_channels'
       output%left_operand => tsource
    end if

  end function merge_scalar_over_channels
!-------------------------------------------------------------------------------
  function get_partial_merge_scalar_over_channels(this, upstream_grad) result(output)
    !! Gradient of merge_scalar_over_channels with respect to its left operand
    !! (tsource), expressed as a differentiable array operation.
    !!
    !! The stored mask is re-applied to the upstream gradient, with zero placed
    !! where the mask is false (fsource is a plain real scalar and is not
    !! tracked for gradients).
    implicit none
    class(array_type), intent(inout) :: this
    !! Result node produced by merge_scalar_over_channels (supplies the mask)
    type(array_type), intent(in) :: upstream_grad
    !! Gradient flowing in from the consumer of `this`
    type(array_type) :: output
    !! Gradient with respect to the left operand

    ! NOTE(review): merge_scalar_over_channels returns a pointer; this
    ! assignment copies its target into `output` and the pointer is not freed
    ! here. Also, upstream_grad is not a TARGET, so if it requires gradients
    ! the temporary's left_operand association is undefined after the call.
    ! Confirm both against how create_result allocates and how the graph is
    ! consumed - TODO.
    output = merge_scalar_over_channels(upstream_grad, 0._real32, this%mask)

  end function get_partial_merge_scalar_over_channels
!-------------------------------------------------------------------------------
  pure subroutine get_partial_merge_scalar_over_channels_val( &
       this, upstream_grad, output &
  )
    !! Raw-value gradient of merge_scalar_over_channels with respect to tsource.
    !!
    !! output(idx, s) = upstream_grad(idx, s) where the spatial mask is true,
    !! and 0 otherwise, with the same flattened (element, channel) indexing as
    !! the forward pass.
    implicit none
    class(array_type), intent(in) :: this
    !! Result node produced by merge_scalar_over_channels (mask and operand shape)
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Upstream gradient values, indexed (flattened element, sample)
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient values for the left operand, same layout as upstream_grad

    integer :: i, m, s
    integer :: num_elements, num_dims, num_channels, num_samples

    num_dims = size(this%left_operand%shape)
    num_elements = product(this%left_operand%shape(1:num_dims - 1))
    num_channels = this%left_operand%shape(num_dims)
    num_samples = size(upstream_grad, 2)

    do concurrent(s = 1:num_samples, m = 1:num_channels)
       do concurrent(i = 1:num_elements)
          ! Read the upstream gradient for channel m, not channel 1.
          if(this%mask(i,1)) then
             output(i + (m-1) * num_elements, s) = &
                  upstream_grad(i + (m-1) * num_elements, s)
          else
             output(i + (m-1) * num_elements, s) = 0._real32
          end if
       end do
    end do

  end subroutine get_partial_merge_scalar_over_channels_val
!###############################################################################

end submodule athena__diffstruc_extd_submodule_merge
86