GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_diffstruc_extd_sub_pool.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 279 334 83.5%
Functions: 0 0 -%
Branches: 843 1518 55.5%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_pool
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
7 2 module function avgpool1d(input, pool_size, stride) result(output)
8 !! 1D average pooling operation
9 implicit none
10
11 ! Arguments
12 type(array_type), intent(in), target :: input
13 integer, intent(in) :: pool_size
14 integer, intent(in) :: stride
15 type(array_type), pointer :: output
16
17 ! Local variables
18 integer :: i, m, s
19 integer :: stride_idx, idx
20 integer, dimension(3) :: output_shape
21
22 output_shape = [ &
23 (input%shape(1) - pool_size) / stride + 1, &
24 input%shape(2), &
25
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✓ Branch 8 taken 6 times.
✓ Branch 9 taken 2 times.
8 size(input%val, dim=2)]
26 2 output => input%create_result(array_shape = output_shape)
27 do concurrent(&
28 s = 1:output_shape(3), &
29 m = 1:output_shape(2), &
30 2 i = 1:output_shape(1))
31
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 stride_idx = (i - 1) * stride + (m - 1) * input%shape(1)
32 18 idx = i + (m - 1) * output_shape(1)
33 output%val(idx, s) = sum( &
34 input%val( stride_idx + 1 : stride_idx + pool_size, s ) &
35
18/28
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 18 times.
✓ Branch 3 taken 6 times.
✓ Branch 4 taken 18 times.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 18 times.
✓ Branch 18 taken 54 times.
✓ Branch 19 taken 18 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 18 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 18 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 18 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 18 times.
98 ) / pool_size
36 end do
37
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
2 allocate(output%adj_ja(1,2))
38
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 output%adj_ja(1,1) = pool_size
39
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 output%adj_ja(1,2) = stride
40
41 2 output%get_partial_left => get_partial_avgpool1d
42 2 output%get_partial_left_val => get_partial_avgpool1d_val
43
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 if(input%requires_grad)then
44 1 output%requires_grad = .true.
45 1 output%is_forward = input%is_forward
46 1 output%operation = 'avgpool'
47 1 output%left_operand => input
48 end if
49
50 2 end function avgpool1d
51 !-------------------------------------------------------------------------------
52 function get_partial_avgpool1d(this, upstream_grad) result(output)
53 !! Get the partial derivative for average pooling
54 implicit none
55
56 ! Arguments
57 class(array_type), intent(inout) :: this
58 type(array_type), intent(in) :: upstream_grad
59 type(array_type) :: output
60
61 call output%allocate(array_shape = &
62 [ this%left_operand%shape, size(this%val, dim=2) ] &
63 )
64 call this%get_partial_left_val(upstream_grad%val, output%val)
65
66 end function get_partial_avgpool1d
67 !-------------------------------------------------------------------------------
68
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_avgpool1d_val(this, upstream_grad, output)
69 !! Optimised backward pass for 1D average pooling
70 implicit none
71
72 ! Arguments
73 class(array_type), intent(in) :: this
74 real(real32), dimension(:,:), intent(in) :: upstream_grad
75 real(real32), dimension(:,:), intent(out) :: output
76
77 ! Local variables
78 integer :: i, m, s, p
79 integer :: base_idx, out_idx, input_h
80 real(real32) :: pool_norm, grad_val
81 integer, dimension(3) :: input_shape
82 integer, dimension(1) :: pool_size, stride
83
84 ! Unpack parameters
85
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 2 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 3 times.
✓ Branch 15 taken 1 times.
6 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
86
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 pool_size(1) = this%adj_ja(1,1)
87
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 stride(1) = this%adj_ja(1,2)
88 1 input_h = input_shape(1)
89
90
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 27 times.
✓ Branch 21 taken 1 times.
29 output = 0._real32
91
92 1 pool_norm = 1.0_real32 / real(pool_size(1), real32)
93
94 ! Parallelised over batch and spatial/channel dimensions
95 2 do concurrent(s = 1:input_shape(3), m = 1:this%shape(2), &
96
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 i = 1:this%shape(1))
97
98 ! Compute indices once
99 9 base_idx = (i - 1) * stride(1) + (m - 1) * input_h
100
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
9 out_idx = i + (m - 1) * this%shape(1)
101
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
9 grad_val = upstream_grad(out_idx, s) * pool_norm
102
103 ! Distribute gradient over pooling window
104
8/8
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
✓ Branch 4 taken 9 times.
✓ Branch 5 taken 9 times.
✓ Branch 6 taken 27 times.
✓ Branch 7 taken 9 times.
58 do p = 0, pool_size(1) - 1
105
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 27 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 27 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 27 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 27 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 27 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 27 times.
36 output(base_idx + p + 1, s) = output(base_idx + p + 1, s) + grad_val
106 end do
107 end do
108
109 1 end subroutine get_partial_avgpool1d_val
110 !###############################################################################
111
112
113 !###############################################################################
114 2 module function avgpool2d(input, pool_size, stride) result(output)
115 !! 2D average pooling operation
116 implicit none
117
118 ! Arguments
119 type(array_type), intent(in), target :: input
120 integer, dimension(2), intent(in) :: pool_size
121 integer, dimension(2), intent(in) :: stride
122 type(array_type), pointer :: output
123
124 ! Local variables
125 integer :: i, j, m, s, i_step, j_step
126 integer :: stride_idx, idx, multiplier
127 integer :: channel_size_in, channel_size_out
128 real(real32) :: pool_sum, pool_norm
129 integer, dimension(4) :: output_shape
130
131 output_shape = [ &
132 (input%shape(1) - pool_size(1)) / stride(1) + 1, &
133 (input%shape(2) - pool_size(2)) / stride(2) + 1, &
134 input%shape(3), &
135
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✓ Branch 12 taken 8 times.
✓ Branch 13 taken 2 times.
10 size(input%val, dim=2)]
136 2 output => input%create_result(array_shape = output_shape)
137 2 pool_norm = 1.0_real32 / real(pool_size(1) * pool_size(2), real32)
138
139 ! Pre-compute as integers
140
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 channel_size_in = input%shape(1) * input%shape(2)
141 2 channel_size_out = output_shape(1) * output_shape(2)
142
143 do concurrent(&
144 s = 1:output_shape(4), &
145 m = 1:output_shape(3), &
146 j = 1:output_shape(2), &
147 2 i = 1:output_shape(1))
148
149 ! Compute indices once
150 stride_idx = (i-1)*stride(1) + &
151 ((j-1)*stride(2)) * input%shape(1) + &
152
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 54 times.
54 (m-1) * channel_size_in
153 54 idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out
154
155 54 pool_sum = 0._real32
156
2/2
✓ Branch 0 taken 162 times.
✓ Branch 1 taken 54 times.
216 do j_step = 0, pool_size(2)-1
157
2/2
✓ Branch 0 taken 486 times.
✓ Branch 1 taken 162 times.
702 do i_step = 0, pool_size(1)-1
158 pool_sum = pool_sum + &
159
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 486 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 486 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 486 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 486 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 486 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 486 times.
648 input%val(stride_idx + i_step + j_step * input%shape(1) + 1, s)
160 end do
161 end do
162
12/16
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 18 times.
✓ Branch 3 taken 6 times.
✓ Branch 4 taken 54 times.
✓ Branch 5 taken 18 times.
✓ Branch 6 taken 54 times.
✓ Branch 7 taken 54 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 54 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 54 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 54 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 54 times.
134 output%val(idx, s) = pool_sum * pool_norm
163 end do
164
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
2 allocate(output%adj_ja(2,2))
165
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 4 times.
✓ Branch 22 taken 2 times.
6 output%adj_ja(:,1) = pool_size
166
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 4 times.
✓ Branch 22 taken 2 times.
6 output%adj_ja(:,2) = stride
167
168 2 output%get_partial_left => get_partial_avgpool2d
169 2 output%get_partial_left_val => get_partial_avgpool2d_val
170
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 if(input%requires_grad)then
171 1 output%requires_grad = .true.
172 1 output%is_forward = input%is_forward
173 1 output%operation = 'avgpool'
174 1 output%left_operand => input
175 end if
176
177 2 end function avgpool2d
178 !-------------------------------------------------------------------------------
179 function get_partial_avgpool2d(this, upstream_grad) result(output)
180 !! Get the partial derivative for average pooling
181 implicit none
182
183 ! Arguments
184 class(array_type), intent(inout) :: this
185 type(array_type), intent(in) :: upstream_grad
186 type(array_type) :: output
187
188 call output%allocate(array_shape = &
189 [ this%left_operand%shape, size(this%val, dim=2) ] &
190 )
191 call this%get_partial_left_val(upstream_grad%val, output%val)
192
193 end function get_partial_avgpool2d
194 !-------------------------------------------------------------------------------
195
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_avgpool2d_val(this, upstream_grad, output)
196 !! Optimised backward pass for 2D average pooling
197 implicit none
198
199 ! Arguments
200 class(array_type), intent(in) :: this
201 real(real32), dimension(:,:), intent(in) :: upstream_grad
202 real(real32), dimension(:,:), intent(out) :: output
203
204 ! Local variables
205 integer :: i, j, m, s
206 integer :: i_step, j_step
207 integer :: base_idx, in_idx, out_idx, input_h
208 integer :: channel_size_in, channel_size_out
209 real(real32) :: pool_norm, grad_val
210 integer, dimension(4) :: input_shape
211 integer, dimension(2) :: pool_size, stride
212
213 ! Unpack parameters
214
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 3 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 4 times.
✓ Branch 15 taken 1 times.
8 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
215
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 2 times.
✓ Branch 22 taken 1 times.
3 pool_size = this%adj_ja(:,1)
216
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 2 times.
✓ Branch 22 taken 1 times.
3 stride = this%adj_ja(:,2)
217 1 input_h = input_shape(1)
218 1 channel_size_in = input_h * input_shape(2)
219
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 channel_size_out = this%shape(1) * this%shape(2)
220
221
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 243 times.
✓ Branch 21 taken 1 times.
245 output = 0._real32
222
223 1 pool_norm = 1.0_real32 / real(pool_size(1) * pool_size(2), real32)
224
225 do concurrent( &
226 s = 1:input_shape(4), &
227 2 m = 1:this%shape(3), &
228 2 j = 1:this%shape(2), &
229
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
1 i = 1:this%shape(1))
230
231 ! Compute indices once
232 base_idx = (i-1) * stride(1) + ((j-1) * stride(2)) * input_h + &
233 27 (m-1) * channel_size_in
234
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
27 out_idx = i + (j-1) * this%shape(1) + (m-1) * channel_size_out
235
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 27 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 27 times.
27 grad_val = upstream_grad(out_idx, s) * pool_norm
236
237 ! Distribute gradient over pooling window
238
10/10
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
✓ Branch 4 taken 27 times.
✓ Branch 5 taken 9 times.
✓ Branch 6 taken 27 times.
✓ Branch 7 taken 27 times.
✓ Branch 8 taken 81 times.
✓ Branch 9 taken 27 times.
175 do j_step = 0, pool_size(2) - 1
239
2/2
✓ Branch 0 taken 243 times.
✓ Branch 1 taken 81 times.
351 do i_step = 0, pool_size(1) - 1
240 243 in_idx = base_idx + i_step + j_step * input_h + 1
241
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 243 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 243 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 243 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 243 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 243 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 243 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 243 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 243 times.
324 output(in_idx, s) = output(in_idx, s) + grad_val
242 end do
243 end do
244 end do
245
246 1 end subroutine get_partial_avgpool2d_val
247 !###############################################################################
248
249
250 !###############################################################################
251 2 module function avgpool3d(input, pool_size, stride) result(output)
252 !! 3D average pooling operation
253 implicit none
254
255 ! Arguments
256 type(array_type), intent(in), target :: input
257 integer, dimension(3), intent(in) :: pool_size
258 integer, dimension(3), intent(in) :: stride
259 type(array_type), pointer :: output
260
261 ! Local variables
262 integer :: i, j, k, m, s
263 integer :: i_step, j_step, k_step
264 integer :: stride_idx, idx
265 integer :: channel_size_in, channel_size_out
266 real(real32) :: pool_sum, pool_norm
267 integer, dimension(5) :: output_shape
268
269 ! output_shape = [H_out, W_out, D_out, C, B]
270 output_shape = [ &
271 (input%shape(1) - pool_size(1)) / stride(1) + 1, &
272 (input%shape(2) - pool_size(2)) / stride(2) + 1, &
273 (input%shape(3) - pool_size(3)) / stride(3) + 1, &
274 input%shape(4), &
275
10/18
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 10 times.
✓ Branch 17 taken 2 times.
12 size(input%val, dim=2) ]
276
277 2 output => input%create_result(array_shape = output_shape)
278
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 pool_norm = 1.0_real32 / real(product(pool_size), real32)
279
280 ! Pre-compute as integers
281
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
2 channel_size_in = input%shape(1) * input%shape(2) * input%shape(3)
282 2 channel_size_out = output_shape(1) * output_shape(2) * output_shape(3)
283
284 do concurrent( &
285 s = 1:output_shape(5), &
286 m = 1:output_shape(4), &
287 k = 1:output_shape(3), &
288 j = 1:output_shape(2), &
289 2 i = 1:output_shape(1))
290
291 ! Compute indices once
292 stride_idx = ((i-1)*stride(1)) + &
293 ((j-1)*stride(2)) * input%shape(1) + &
294 ((k-1)*stride(3)) * input%shape(1) * input%shape(2) + &
295
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 162 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 162 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 162 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 162 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 162 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 162 times.
162 (m-1) * channel_size_in
296 idx = i + (j-1) * output_shape(1) + &
297 (k-1) * output_shape(1)*output_shape(2) + &
298 162 (m-1) * channel_size_out
299
300 162 pool_sum = 0._real32
301
2/2
✓ Branch 0 taken 486 times.
✓ Branch 1 taken 162 times.
648 do k_step = 0, pool_size(3)-1
302
2/2
✓ Branch 0 taken 1458 times.
✓ Branch 1 taken 486 times.
2106 do j_step = 0, pool_size(2)-1
303
2/2
✓ Branch 0 taken 4374 times.
✓ Branch 1 taken 1458 times.
6318 do i_step = 0, pool_size(1)-1
304 pool_sum = pool_sum + input%val(stride_idx + i_step + &
305 j_step * input%shape(1) + &
306
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 4374 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4374 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4374 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4374 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4374 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4374 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 4374 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 4374 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 4374 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 4374 times.
5832 k_step * input%shape(1) * input%shape(2) + 1, s)
307 end do
308 end do
309 end do
310
311
14/18
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 18 times.
✓ Branch 3 taken 6 times.
✓ Branch 4 taken 54 times.
✓ Branch 5 taken 18 times.
✓ Branch 6 taken 162 times.
✓ Branch 7 taken 54 times.
✓ Branch 8 taken 162 times.
✓ Branch 9 taken 162 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 162 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 162 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 162 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 162 times.
404 output%val(idx, s) = pool_sum * pool_norm
312 end do
313
314
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
2 allocate(output%adj_ja(3,2))
315
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 2 times.
8 output%adj_ja(:,1) = pool_size
316
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 2 times.
8 output%adj_ja(:,2) = stride
317
318 2 output%get_partial_left => get_partial_avgpool3d
319 2 output%get_partial_left_val => get_partial_avgpool3d_val
320
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 if(input%requires_grad)then
321 1 output%requires_grad = .true.
322 1 output%is_forward = input%is_forward
323 1 output%operation = 'avgpool3d'
324 1 output%left_operand => input
325 end if
326
327 2 end function avgpool3d
328 !-------------------------------------------------------------------------------
329 function get_partial_avgpool3d(this, upstream_grad) result(output)
330 !! Get the partial derivative for 3D average pooling
331 implicit none
332
333 ! Arguments
334 class(array_type), intent(inout) :: this
335 type(array_type), intent(in) :: upstream_grad
336 type(array_type) :: output
337
338 call output%allocate(array_shape = &
339 [ this%left_operand%shape, size(this%val, dim=2) ] &
340 )
341 call this%get_partial_left_val(upstream_grad%val, output%val)
342
343 end function get_partial_avgpool3d
344 !-------------------------------------------------------------------------------
345
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_avgpool3d_val(this, upstream_grad, output)
346 !! Optimised backward pass for 3D average pooling
347 implicit none
348
349 ! Arguments
350 class(array_type), intent(in) :: this
351 real(real32), dimension(:,:), intent(in) :: upstream_grad
352 real(real32), dimension(:,:), intent(out) :: output
353
354 ! Local variables
355 integer :: i, j, k, m, s
356 integer :: i_step, j_step, k_step
357 integer :: base_idx, in_idx, out_idx, input_h, input_hw
358 integer :: channel_size_in, channel_size_out
359 real(real32) :: pool_norm, grad_val
360 integer, dimension(5) :: input_shape
361 integer, dimension(3) :: pool_size, stride
362
363 ! Unpack parameters
364
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 4 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 5 times.
✓ Branch 15 taken 1 times.
10 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
365
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 pool_size = this%adj_ja(:,1)
366
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 stride = this%adj_ja(:,2)
367 1 input_h = input_shape(1)
368 1 input_hw = input_h * input_shape(2)
369 1 channel_size_in = input_hw * input_shape(3)
370
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
1 channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)
371
372
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 2187 times.
✓ Branch 21 taken 1 times.
2189 output = 0._real32
373
374
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
4 pool_norm = 1.0_real32 / real(product(pool_size), real32)
375
376 do concurrent( &
377 s = 1:input_shape(5), &
378 2 m = 1:this%shape(4), &
379 2 k = 1:this%shape(3), &
380 2 j = 1:this%shape(2), &
381
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
1 i = 1:this%shape(1))
382
383 ! Compute indices once
384 base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
385 81 ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in
386 162 out_idx = i + (j-1) * this%shape(1) + &
387 324 (k-1) * this%shape(1)*this%shape(2) + &
388
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 81 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 81 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 81 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 81 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 81 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 81 times.
81 (m-1) * channel_size_out
389
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 81 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 81 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 81 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 81 times.
81 grad_val = upstream_grad(out_idx, s) * pool_norm
390
391 ! Distribute gradient over pooling window
392
12/12
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 9 times.
✓ Branch 3 taken 3 times.
✓ Branch 4 taken 27 times.
✓ Branch 5 taken 9 times.
✓ Branch 6 taken 81 times.
✓ Branch 7 taken 27 times.
✓ Branch 8 taken 81 times.
✓ Branch 9 taken 81 times.
✓ Branch 10 taken 243 times.
✓ Branch 11 taken 81 times.
526 do k_step = 0, pool_size(3)-1
393
2/2
✓ Branch 0 taken 729 times.
✓ Branch 1 taken 243 times.
1053 do j_step = 0, pool_size(2)-1
394
2/2
✓ Branch 0 taken 2187 times.
✓ Branch 1 taken 729 times.
3159 do i_step = 0, pool_size(1)-1
395 in_idx = base_idx + i_step + j_step * input_h + &
396 2187 k_step * input_hw + 1
397
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2187 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2187 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2187 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2187 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2187 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2187 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2187 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2187 times.
2916 output(in_idx, s) = output(in_idx, s) + grad_val
398 end do
399 end do
400 end do
401 end do
402
403 1 end subroutine get_partial_avgpool3d_val
404 !###############################################################################
405
406
407 !##############################################################################!
408 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
409 !##############################################################################!
410
411
412 !###############################################################################
413 1 module function maxpool1d(input, pool_size, stride) result(output)
414 !! 1D max pooling operation
415 implicit none
416
417 ! Arguments
418 type(array_type), intent(in), target :: input
419 integer, intent(in) :: pool_size
420 integer, intent(in) :: stride
421 type(array_type), pointer :: output
422
423 ! Local variables
424 integer :: i, m, s
425 integer :: stride_idx, idx
426 integer, dimension(3) :: output_shape
427
428 output_shape = [ &
429 (input%shape(1) - pool_size) / stride + 1, &
430 input%shape(2), &
431
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 8 taken 3 times.
✓ Branch 9 taken 1 times.
4 size(input%val, dim=2)]
432 1 output => input%create_result(array_shape = output_shape)
433 do concurrent(&
434 s = 1:output_shape(3), &
435 m = 1:output_shape(2), &
436 1 i = 1:output_shape(1))
437
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
24 stride_idx = (i - 1) * stride + (m - 1) * input%shape(1)
438 24 idx = i + (m - 1) * output_shape(1)
439 output%val(idx, s) = maxval( &
440 input%val( stride_idx + 1 : stride_idx + pool_size, s ) &
441
22/36
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 24 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 24 times.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 24 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 24 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 24 times.
✓ Branch 18 taken 24 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 24 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 48 times.
✓ Branch 25 taken 24 times.
✓ Branch 26 taken 3 times.
✓ Branch 27 taken 45 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 24 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 24 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 24 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 24 times.
105 )
442 end do
443
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
1 allocate(output%adj_ja(1,2))
444
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 output%adj_ja(1,1) = pool_size
445
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 output%adj_ja(1,2) = stride
446
447 1 output%get_partial_left => get_partial_maxpool1d
448 1 output%get_partial_left_val => get_partial_maxpool1d_val
449
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(input%requires_grad)then
450 1 output%requires_grad = .true.
451 1 output%is_forward = input%is_forward
452 1 output%operation = 'maxpool'
453 1 output%left_operand => input
454 end if
455
456 1 end function maxpool1d
457 !-------------------------------------------------------------------------------
458 function get_partial_maxpool1d(this, upstream_grad) result(output)
459 !! Get the partial derivative for max pooling
460 implicit none
461
462 ! Arguments
463 class(array_type), intent(inout) :: this
464 type(array_type), intent(in) :: upstream_grad
465 type(array_type) :: output
466
467 call output%allocate(array_shape = &
468 [ this%left_operand%shape, size(this%val, dim=2) ] &
469 )
470 call this%get_partial_left_val(upstream_grad%val, output%val)
471
472 end function get_partial_maxpool1d
473 !-------------------------------------------------------------------------------
474
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_maxpool1d_val(this, upstream_grad, output)
475 !! Optimised backward pass for 1D max pooling
476 implicit none
477
478 ! Arguments
479 class(array_type), intent(in) :: this
480 real(real32), dimension(:,:), intent(in) :: upstream_grad
481 real(real32), dimension(:,:), intent(out) :: output
482
483 ! Local variables
484 integer :: i, m, s, p
485 integer :: base_idx, max_idx, out_idx, input_h
486 real(real32) :: pool_max, grad_val
487 integer, dimension(3) :: input_shape
488 integer, dimension(1) :: pool_size, stride
489
490
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 2 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 3 times.
✓ Branch 15 taken 1 times.
6 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
491
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 pool_size(1) = this%adj_ja(1,1)
492
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 stride(1) = this%adj_ja(1,2)
493 1 input_h = input_shape(1)
494
495
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 54 times.
✓ Branch 21 taken 1 times.
56 output = 0._real32
496
497 2 do concurrent(s = 1:input_shape(3), m = 1:this%shape(2), &
498
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 i = 1:this%shape(1))
499
500 ! Compute indices once
501 24 base_idx = (i - 1) * stride(1) + (m - 1) * input_h
502
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
24 out_idx = i + (m - 1) * this%shape(1)
503
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
24 grad_val = upstream_grad(out_idx, s)
504
505 ! Find max value location - initialise with first element
506 24 max_idx = base_idx + 1
507
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
24 pool_max = this%left_operand%val(max_idx, s)
508
509 ! Search remaining elements for max
510
2/2
✓ Branch 0 taken 48 times.
✓ Branch 1 taken 24 times.
72 do p = 1, pool_size(1) - 1
511
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 48 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 48 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 48 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 48 times.
✓ Branch 12 taken 3 times.
✓ Branch 13 taken 45 times.
72 if(this%left_operand%val(base_idx + p + 1, s) .gt. pool_max)then
512
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
3 pool_max = this%left_operand%val(base_idx + p + 1, s)
513 3 max_idx = base_idx + p + 1
514 end if
515 end do
516
517 ! Assign gradient to max location
518
14/22
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 24 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 24 times.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 24 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 24 times.
57 output(max_idx, s) = output(max_idx, s) + grad_val
519 end do
520
521 1 end subroutine get_partial_maxpool1d_val
522 !###############################################################################
523
524
525 !###############################################################################
526 3 module function maxpool2d(input, pool_size, stride) result(output)
527 !! 2D max pooling operation
528 implicit none
529
530 ! Arguments
531 type(array_type), intent(in), target :: input
532 integer, dimension(2), intent(in) :: pool_size
533 integer, dimension(2), intent(in) :: stride
534 type(array_type), pointer :: output
535
536 ! Local variables
537 integer :: i, j, m, s, i_step, j_step
538 integer :: base_idx, stride_idx, idx, input_h
539 real(real32) :: pool_max, val_tmp
540 integer :: channel_size_in, channel_size_out
541 integer, dimension(4) :: output_shape
542
543 output_shape = [ &
544 (input%shape(1) - pool_size(1)) / stride(1) + 1, &
545 (input%shape(2) - pool_size(2)) / stride(2) + 1, &
546 input%shape(3), &
547
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✓ Branch 12 taken 12 times.
✓ Branch 13 taken 3 times.
15 size(input%val, dim=2)]
548 3 output => input%create_result(array_shape = output_shape)
549
550 ! Pre-compute as integers to avoid type conversion in loop
551
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 input_h = input%shape(1)
552
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
3 channel_size_in = input_h * input%shape(2)
553 3 channel_size_out = output_shape(1) * output_shape(2)
554
555 do concurrent(&
556 s = 1:output_shape(4), &
557 m = 1:output_shape(3), &
558 j = 1:output_shape(2), &
559 3 i = 1:output_shape(1))
560
561 ! Compute indices once per output position
562 base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
563 264 (m-1) * channel_size_in
564 264 idx = i + (j - 1) * output_shape(1) + (m - 1) * channel_size_out
565
566 ! Find max value - initialise with first element for better performance
567 264 stride_idx = base_idx + 1
568
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 264 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 264 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 264 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 264 times.
264 pool_max = input%val(stride_idx, s)
569
570 ! Continue with remaining elements
571
2/2
✓ Branch 0 taken 720 times.
✓ Branch 1 taken 264 times.
984 do j_step = 0, pool_size(2)-1
572
2/2
✓ Branch 0 taken 2016 times.
✓ Branch 1 taken 720 times.
3000 do i_step = 0, pool_size(1)-1
573
2/2
✓ Branch 0 taken 264 times.
✓ Branch 1 taken 1752 times.
2016 if(i_step .eq. 0 .and. j_step .eq. 0) cycle ! Already processed
574 1752 stride_idx = base_idx + i_step + j_step * input_h + 1
575
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 1752 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1752 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1752 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1752 times.
✓ Branch 12 taken 59 times.
✓ Branch 13 taken 1693 times.
1752 if(input%val(stride_idx, s) .gt. pool_max) &
576
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 59 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 59 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 59 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 59 times.
779 pool_max = input%val(stride_idx, s)
577 end do
578 end do
579
580
12/16
✓ Branch 0 taken 14 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 82 times.
✓ Branch 3 taken 14 times.
✓ Branch 4 taken 264 times.
✓ Branch 5 taken 82 times.
✓ Branch 6 taken 264 times.
✓ Branch 7 taken 264 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 264 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 264 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 264 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 264 times.
627 output%val(idx, s) = pool_max
581 end do
582
583
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
3 allocate(output%adj_ja(2,2))
584
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 3 times.
9 output%adj_ja(:,1) = pool_size
585
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 3 times.
9 output%adj_ja(:,2) = stride
586
587 3 output%get_partial_left => get_partial_maxpool2d
588 3 output%get_partial_left_val => get_partial_maxpool2d_val
589
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 if(input%requires_grad)then
590 3 output%requires_grad = .true.
591 3 output%is_forward = input%is_forward
592 3 output%operation = 'maxpool'
593 3 output%left_operand => input
594 end if
595
596 3 end function maxpool2d
597 !-------------------------------------------------------------------------------
598 function get_partial_maxpool2d(this, upstream_grad) result(output)
599 !! Get the partial derivative for max pooling
600 implicit none
601
602 ! Arguments
603 class(array_type), intent(inout) :: this
604 type(array_type), intent(in) :: upstream_grad
605 type(array_type) :: output
606
607 call output%allocate(array_shape = &
608 [ this%left_operand%shape, size(this%val, dim=2) ] &
609 )
610 call this%get_partial_left_val(upstream_grad%val, output%val)
611
612 end function get_partial_maxpool2d
613 !-------------------------------------------------------------------------------
614
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_maxpool2d_val(this, upstream_grad, output)
615 implicit none
616
617 ! Arguments
618 class(array_type), intent(in) :: this
619 real(real32), dimension(:,:), intent(in) :: upstream_grad
620 real(real32), dimension(:,:), intent(out) :: output
621
622 ! Local variables
623 integer :: i, j, m, s
624 integer :: i_step, j_step
625 integer :: base_idx, in_idx, out_idx, max_idx, input_h
626 real(real32) :: pool_max, val_tmp, grad_val
627 integer :: channel_size_in, channel_size_out
628 integer, dimension(4) :: input_shape
629 integer, dimension(2) :: pool_size, stride
630
631 ! Unpack parameters
632
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 3 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 4 times.
✓ Branch 15 taken 1 times.
8 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
633
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 2 times.
✓ Branch 22 taken 1 times.
3 pool_size = this%adj_ja(:,1)
634
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 2 times.
✓ Branch 22 taken 1 times.
3 stride = this%adj_ja(:,2)
635 1 input_h = input_shape(1)
636 1 channel_size_in = input_h * input_shape(2)
637
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 channel_size_out = this%shape(1) * this%shape(2)
638
639
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 972 times.
✓ Branch 21 taken 1 times.
974 output = 0._real32
640
641 ! Parallelised over batch and spatial/channel dimensions
642 2 do concurrent(s = 1:input_shape(4), m = 1:this%shape(3), &
643
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
1 j = 1:this%shape(2), i = 1:this%shape(1))
644
645 ! Compute indices once
646 base_idx = (i-1) * stride(1) + ((j-1) * stride(2)) * input_h + &
647 192 (m-1) * channel_size_in
648
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
192 out_idx = i + (j-1) * this%shape(1) + (m-1) * channel_size_out
649
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
192 grad_val = upstream_grad(out_idx, s)
650
651 ! Find max value location - initialise with first element
652 192 max_idx = base_idx + 1
653
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 192 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 192 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 192 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 192 times.
192 pool_max = this%left_operand%val(max_idx, s)
654
655 ! Search remaining elements for max
656
2/2
✓ Branch 0 taken 576 times.
✓ Branch 1 taken 192 times.
768 do j_step = 0, pool_size(2) - 1
657
2/2
✓ Branch 0 taken 1728 times.
✓ Branch 1 taken 576 times.
2496 do i_step = 0, pool_size(1) - 1
658
2/2
✓ Branch 0 taken 192 times.
✓ Branch 1 taken 1536 times.
1728 if(i_step .eq. 0 .and. j_step .eq. 0) cycle ! Already processed
659 1536 in_idx = base_idx + i_step + j_step * input_h + 1
660
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1536 times.
1536 val_tmp = this%left_operand%val(in_idx, s)
661
662
2/2
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 1527 times.
2112 if(val_tmp .gt. pool_max)then
663 9 pool_max = val_tmp
664 9 max_idx = in_idx
665 end if
666 end do
667 end do
668
669 ! Assign gradient to max location
670
16/24
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 64 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 192 times.
✓ Branch 5 taken 64 times.
✓ Branch 6 taken 192 times.
✓ Branch 7 taken 192 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 192 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 192 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 192 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 192 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 192 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 192 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 192 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 192 times.
457 output(max_idx, s) = output(max_idx, s) + grad_val
671 end do
672
673 1 end subroutine get_partial_maxpool2d_val
674 !###############################################################################
675
676
677 !###############################################################################
678 1 module function maxpool3d(input, pool_size, stride) result(output)
679 !! 3D max pooling operation
680 implicit none
681
682 ! Arguments
683 type(array_type), intent(in), target :: input
684 integer, dimension(3), intent(in) :: pool_size
685 integer, dimension(3), intent(in) :: stride
686 type(array_type), pointer :: output
687
688 ! Local variables
689 integer :: i, j, k, m, s
690 integer :: i_step, j_step, k_step
691 integer :: stride_idx, idx
692 integer :: channel_size_in, channel_size_out
693 real(real32) :: pool_max
694 integer, dimension(5) :: output_shape
695
696 ! output_shape = [H_out, W_out, D_out, C, B]
697 output_shape = [ &
698 (input%shape(1) - pool_size(1)) / stride(1) + 1, &
699 (input%shape(2) - pool_size(2)) / stride(2) + 1, &
700 (input%shape(3) - pool_size(3)) / stride(3) + 1, &
701 input%shape(4), &
702
10/18
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 16 taken 5 times.
✓ Branch 17 taken 1 times.
6 size(input%val, dim=2) ]
703
704 1 output => input%create_result(array_shape = output_shape)
705
706 ! Pre-compute as integers
707
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
1 channel_size_in = input%shape(1) * input%shape(2) * input%shape(3)
708 1 channel_size_out = output_shape(1) * output_shape(2) * output_shape(3)
709
710 do concurrent( &
711 s = 1:output_shape(5), &
712 m = 1:output_shape(4), &
713 k = 1:output_shape(3), &
714 j = 1:output_shape(2), &
715 1 i = 1:output_shape(1))
716
717 ! Compute indices once per output position
718 stride_idx = ((i-1)*stride(1)) + &
719 ((j-1)*stride(2)) * input%shape(1) + &
720 ((k-1)*stride(3)) * input%shape(1) * input%shape(2) + &
721
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1536 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1536 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1536 times.
1536 (m-1) * channel_size_in + 1
722 idx = i + (j-1) * output_shape(1) + &
723 (k-1) * output_shape(1)*output_shape(2) + &
724 1536 (m-1) * channel_size_out
725
726 ! Find max value - initialise with first element
727
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1536 times.
1536 pool_max = input%val(stride_idx, s)
728
729
2/2
✓ Branch 0 taken 4608 times.
✓ Branch 1 taken 1536 times.
6144 do k_step = 0, pool_size(3)-1
730
2/2
✓ Branch 0 taken 13824 times.
✓ Branch 1 taken 4608 times.
19968 do j_step = 0, pool_size(2)-1
731
2/2
✓ Branch 0 taken 41472 times.
✓ Branch 1 taken 13824 times.
59904 do i_step = 0, pool_size(1)-1
732
2/2
✓ Branch 0 taken 1536 times.
✓ Branch 1 taken 39936 times.
41472 if(i_step .eq. 0 .and. j_step .eq. 0 .and. k_step .eq. 0) cycle
733
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 39936 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 39936 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 39936 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 39936 times.
✓ Branch 12 taken 21 times.
✓ Branch 13 taken 39915 times.
79872 if( &
734 input%val( &
735 stride_idx + i_step + &
736
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 39936 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 39936 times.
39936 j_step * input%shape(1) + &
737
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 39936 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 39936 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 39936 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 39936 times.
39936 k_step * input%shape(1) * input%shape(2), s &
738 ) .gt. pool_max &
739 13824 )then
740 pool_max = input%val(stride_idx + i_step + &
741 j_step * input%shape(1) + &
742
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 21 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 21 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 21 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 21 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 21 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 21 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 21 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 21 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 21 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 21 times.
21 k_step * input%shape(1) * input%shape(2), s)
743 end if
744 end do
745 end do
746 end do
747
748
14/18
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 64 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 512 times.
✓ Branch 5 taken 64 times.
✓ Branch 6 taken 1536 times.
✓ Branch 7 taken 512 times.
✓ Branch 8 taken 1536 times.
✓ Branch 9 taken 1536 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1536 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1536 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1536 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1536 times.
3657 output%val(idx, s) = pool_max
749 end do
750
751
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
1 allocate(output%adj_ja(3,2))
752
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 output%adj_ja(:,1) = pool_size
753
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 output%adj_ja(:,2) = stride
754
755 1 output%get_partial_left => get_partial_maxpool3d
756 1 output%get_partial_left_val => get_partial_maxpool3d_val
757
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(input%requires_grad)then
758 1 output%requires_grad = .true.
759 1 output%is_forward = input%is_forward
760 1 output%operation = 'maxpool3d'
761 1 output%left_operand => input
762 end if
763
764 1 end function maxpool3d
765 !-------------------------------------------------------------------------------
766 function get_partial_maxpool3d(this, upstream_grad) result(output)
767 !! Get the partial derivative for 3D max pooling
768 implicit none
769
770 ! Arguments
771 class(array_type), intent(inout) :: this
772 type(array_type), intent(in) :: upstream_grad
773 type(array_type) :: output
774
775 call output%allocate(array_shape = &
776 [ this%left_operand%shape, size(this%val, dim=2) ] &
777 )
778 call this%get_partial_left_val(upstream_grad%val, output%val)
779
780 end function get_partial_maxpool3d
781 !-------------------------------------------------------------------------------
782
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine get_partial_maxpool3d_val(this, upstream_grad, output)
783 !! Optimised backward pass for 3D max pooling
784 implicit none
785
786 ! Arguments
787 class(array_type), intent(in) :: this
788 real(real32), dimension(:,:), intent(in) :: upstream_grad
789 real(real32), dimension(:,:), intent(out) :: output
790
791 ! Local variables
792 integer :: i, j, k, m, s
793 integer :: i_step, j_step, k_step
794 integer :: base_idx, in_idx, out_idx, max_idx
795 integer :: input_h, input_hw
796 integer :: channel_size_in, channel_size_out
797 real(real32) :: pool_max, val_tmp, grad_val
798 integer, dimension(5) :: input_shape
799 integer, dimension(3) :: pool_size, stride
800
801 ! Unpack parameters
802
8/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 4 times.
✓ Branch 13 taken 1 times.
✓ Branch 14 taken 5 times.
✓ Branch 15 taken 1 times.
10 input_shape = [ this%left_operand%shape, size(this%val, dim=2) ]
803
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 pool_size = this%adj_ja(:,1)
804
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 stride = this%adj_ja(:,2)
805 1 input_h = input_shape(1)
806 1 input_hw = input_h * input_shape(2)
807 1 channel_size_in = input_hw * input_shape(3)
808
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
1 channel_size_out = this%shape(1) * this%shape(2) * this%shape(3)
809
810
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 17496 times.
✓ Branch 21 taken 1 times.
17498 output = 0._real32
811
812 ! Parallelised over batch and spatial/channel dimensions
813 2 do concurrent(s = 1:input_shape(5), m = 1:this%shape(4), &
814
8/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
1 k = 1:this%shape(3), j = 1:this%shape(2), i = 1:this%shape(1))
815
816 ! Compute indices once
817 base_idx = (i-1)*stride(1) + ((j-1)*stride(2)) * input_h + &
818 1536 ((k-1)*stride(3)) * input_hw + (m-1) * channel_size_in
819 3072 out_idx = i + (j-1) * this%shape(1) + &
820 6144 (k-1) * this%shape(1)*this%shape(2) + &
821
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1536 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1536 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1536 times.
1536 (m-1) * channel_size_out
822
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1536 times.
1536 grad_val = upstream_grad(out_idx, s)
823
824 ! Find max value location - initialise with first element
825 1536 max_idx = base_idx + 1
826
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1536 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1536 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1536 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1536 times.
1536 pool_max = this%left_operand%val(max_idx, s)
827
828 ! Search remaining elements for max
829
2/2
✓ Branch 0 taken 4608 times.
✓ Branch 1 taken 1536 times.
6144 do k_step = 0, pool_size(3)-1
830
2/2
✓ Branch 0 taken 13824 times.
✓ Branch 1 taken 4608 times.
19968 do j_step = 0, pool_size(2)-1
831
2/2
✓ Branch 0 taken 41472 times.
✓ Branch 1 taken 13824 times.
59904 do i_step = 0, pool_size(1)-1
832
2/2
✓ Branch 0 taken 1536 times.
✓ Branch 1 taken 39936 times.
41472 if(i_step .eq. 0 .and. j_step .eq. 0 .and. k_step .eq. 0) cycle
833 in_idx = base_idx + i_step + j_step * input_h + &
834 39936 k_step * input_hw + 1
835
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 39936 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 39936 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 39936 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 39936 times.
39936 val_tmp = this%left_operand%val(in_idx, s)
836
837
2/2
✓ Branch 0 taken 21 times.
✓ Branch 1 taken 39915 times.
53760 if(val_tmp .gt. pool_max)then
838 21 pool_max = val_tmp
839 21 max_idx = in_idx
840 end if
841 end do
842 end do
843 end do
844
845 ! Assign gradient to max location
846
18/26
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 64 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 512 times.
✓ Branch 5 taken 64 times.
✓ Branch 6 taken 1536 times.
✓ Branch 7 taken 512 times.
✓ Branch 8 taken 1536 times.
✓ Branch 9 taken 1536 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1536 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1536 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1536 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1536 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1536 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1536 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1536 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1536 times.
3657 output(max_idx, s) = output(max_idx, s) + grad_val
847 end do
848
849 1 end subroutine get_partial_maxpool3d_val
850 !###############################################################################
851
852 end submodule athena__diffstruc_extd_submodule_pool
853