GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_diffstruc_extd.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 1 1 100.0%
Functions: 0 0 -%
Branches: 39 61 63.9%

Line Branch Exec Source
1 module athena__diffstruc_extd
2 !! Module for extended differential structure types for Athena
3 use coreutils, only: real32
4 use diffstruc, only: array_type
5 use athena__misc_types, only: facets_type
6 implicit none
7
8
9 private
10
11 public :: array_ptr_type
12 public :: add_layers, concat_layers
13 public :: add_bias
14 public :: piecewise, softmax, swish
15 public :: huber
16 public :: avgpool1d, avgpool2d, avgpool3d
17 public :: maxpool1d, maxpool2d, maxpool3d
18 public :: pad1d, pad2d, pad3d
19 public :: merge_over_channels
20 public :: batchnorm_array_type, batchnorm, batchnorm_inference
21 public :: conv1d, conv2d, conv3d
22 public :: kipf_propagate, kipf_update
23 public :: duvenaud_propagate, duvenaud_update
24 public :: gno_kernel_eval, gno_aggregate
25 public :: lno_encode, lno_decode, elem_scale
26 public :: ono_encode, ono_decode
27
28
29 type, extends(array_type) :: batchnorm_array_type
30 real(real32), dimension(:), allocatable :: mean
31 real(real32), dimension(:), allocatable :: variance
32 real(real32) :: epsilon
33 end type batchnorm_array_type
34
35
36 !-------------------------------------------------------------------------------
37 ! Array container types
38 !-------------------------------------------------------------------------------
39 type :: array_ptr_type
40 type(array_type), pointer :: array(:,:)
41 end type array_ptr_type
42
43 ! Operator interfaces
44 !-----------------------------------------------------------------------------
45 interface add_layers
46 module function add_array_ptr(a, idx1, idx2) result(c)
47 type(array_ptr_type), dimension(:), intent(in) :: a
48 integer, intent(in) :: idx1, idx2
49 type(array_type), pointer :: c
50 end function add_array_ptr
51 end interface
52
53 interface concat_layers
54 module function concat_array_ptr(a, idx1, idx2, dim) result(c)
55 type(array_ptr_type), dimension(:), intent(in) :: a
56 integer, intent(in) :: idx1, idx2, dim
57 type(array_type), pointer :: c
58 end function concat_array_ptr
59 end interface
60 !-------------------------------------------------------------------------------
61
62
63 !-------------------------------------------------------------------------------
64 ! Activation functions and other operations
65 !-------------------------------------------------------------------------------
66 interface
67 module function add_bias(input, bias, dim, dim_act_on_shape) result(output)
68 class(array_type), intent(in), target :: input
69 class(array_type), intent(in), target :: bias
70 integer, intent(in) :: dim
71 logical, intent(in), optional :: dim_act_on_shape
72 type(array_type), pointer :: output
73 end function add_bias
74 end interface
75
76 interface piecewise
77 module function piecewise_array(input, gradient, limit) result( output )
78 class(array_type), intent(in), target :: input
79 real(real32), intent(in) :: gradient
80 real(real32), intent(in) :: limit
81 type(array_type), pointer :: output
82 end function piecewise_array
83 end interface
84
85 interface softmax
86 module function softmax_array(input, dim) result(output)
87 class(array_type), intent(in), target :: input
88 integer, intent(in) :: dim
89 type(array_type), pointer :: output
90 end function softmax_array
91 end interface
92
93 interface swish
94 module function swish_array(input, beta) result(output)
95 class(array_type), intent(in), target :: input
96 real(real32), intent(in) :: beta
97 type(array_type), pointer :: output
98 end function swish_array
99 end interface
100 !-------------------------------------------------------------------------------
101
102
103 !-------------------------------------------------------------------------------
104 ! Loss functions
105 !-------------------------------------------------------------------------------
106 interface huber
107 module function huber_array(delta, gamma) result( output )
108 class(array_type), intent(in), target :: delta
109 real(real32), intent(in) :: gamma
110 type(array_type), pointer :: output
111 end function huber_array
112 end interface
113 !-------------------------------------------------------------------------------
114
115
116 !-------------------------------------------------------------------------------
117 ! Layer operations
118 !-------------------------------------------------------------------------------
119 interface
120 module function avgpool1d(input, pool_size, stride) result(output)
121 type(array_type), intent(in), target :: input
122 integer, intent(in) :: pool_size
123 integer, intent(in) :: stride
124 type(array_type), pointer :: output
125 end function avgpool1d
126
127 module function avgpool2d(input, pool_size, stride) result(output)
128 type(array_type), intent(in), target :: input
129 integer, dimension(2), intent(in) :: pool_size
130 integer, dimension(2), intent(in) :: stride
131 type(array_type), pointer :: output
132 end function avgpool2d
133
134 module function avgpool3d(input, pool_size, stride) result(output)
135 type(array_type), intent(in), target :: input
136 integer, dimension(3), intent(in) :: pool_size
137 integer, dimension(3), intent(in) :: stride
138 type(array_type), pointer :: output
139 end function avgpool3d
140 end interface
141
142 interface
143 module function maxpool1d(input, pool_size, stride) result(output)
144 type(array_type), intent(in), target :: input
145 integer, intent(in) :: pool_size
146 integer, intent(in) :: stride
147 type(array_type), pointer :: output
148 end function maxpool1d
149
150 module function maxpool2d(input, pool_size, stride) result(output)
151 type(array_type), intent(in), target :: input
152 integer, dimension(2), intent(in) :: pool_size
153 integer, dimension(2), intent(in) :: stride
154 type(array_type), pointer :: output
155 end function maxpool2d
156
157 module function maxpool3d(input, pool_size, stride) result(output)
158 type(array_type), intent(in), target :: input
159 integer, dimension(3), intent(in) :: pool_size
160 integer, dimension(3), intent(in) :: stride
161 type(array_type), pointer :: output
162 end function maxpool3d
163 end interface
164
165 interface
166 module function pad1d(input, facets, pad_size, imethod) result(output)
167 type(array_type), intent(in), target :: input
168 type(facets_type), intent(in) :: facets
169 integer, intent(in) :: pad_size
170 integer, intent(in) :: imethod
171 type(array_type), pointer :: output
172 end function pad1d
173
174 module function pad2d(input, facets, pad_size, imethod) result(output)
175 type(array_type), intent(in), target :: input
176 type(facets_type), dimension(2), intent(in) :: facets
177 integer, dimension(2), intent(in) :: pad_size
178 integer, intent(in) :: imethod
179 type(array_type), pointer :: output
180 end function pad2d
181
182 module function pad3d(input, facets, pad_size, imethod) result(output)
183 type(array_type), intent(in), target :: input
184 type(facets_type), dimension(3), intent(in) :: facets
185 integer, dimension(3), intent(in) :: pad_size
186 integer, intent(in) :: imethod
187 type(array_type), pointer :: output
188 end function pad3d
189 end interface
190
191 interface merge_over_channels
192 module function merge_scalar_over_channels(tsource, fsource, mask) result(output)
193 class(array_type), intent(in), target :: tsource
194 real(real32), intent(in) :: fsource
195 logical, dimension(:,:), intent(in) :: mask
196 type(array_type), pointer :: output
197 end function merge_scalar_over_channels
198 end interface
199
200 interface
201 module function batchnorm( &
202 input, params, momentum, mean, variance, epsilon &
203 ) result( output )
204 class(array_type), intent(in), target :: input
205 class(array_type), intent(in), target :: params
206 real(real32), intent(in) :: momentum
207 real(real32), dimension(:), intent(in) :: mean
208 real(real32), dimension(:), intent(in) :: variance
209 real(real32), intent(in) :: epsilon
210 type(batchnorm_array_type), pointer :: output
211 end function batchnorm
212
213 module function batchnorm_inference( &
214 input, params, mean, variance, epsilon &
215 ) result( output )
216 class(array_type), intent(in), target :: input
217 class(array_type), intent(in), target :: params
218 real(real32), dimension(:), intent(in) :: mean
219 real(real32), dimension(:), intent(in) :: variance
220 real(real32), intent(in) :: epsilon
221 type(batchnorm_array_type), pointer :: output
222 end function batchnorm_inference
223 end interface
224
225 interface
226 module function conv1d(input, kernel, stride, dilation) result(output)
227 type(array_type), intent(in), target :: input
228 type(array_type), intent(in), target :: kernel
229 integer, intent(in) :: stride
230 integer, intent(in) :: dilation
231 type(array_type), pointer :: output
232 end function conv1d
233
234 module function conv2d(input, kernel, stride, dilation) result(output)
235 type(array_type), intent(in), target :: input
236 type(array_type), intent(in), target :: kernel
237 integer, dimension(2), intent(in) :: stride
238 integer, dimension(2), intent(in) :: dilation
239 type(array_type), pointer :: output
240 end function conv2d
241
242 module function conv3d(input, kernel, stride, dilation) result(output)
243 type(array_type), intent(in), target :: input
244 type(array_type), intent(in), target :: kernel
245 integer, dimension(3), intent(in) :: stride
246 integer, dimension(3), intent(in) :: dilation
247 type(array_type), pointer :: output
248 end function conv3d
249 end interface
250
251 interface
252 module function kipf_propagate(vertex_features, adj_ia, adj_ja) result(c)
253 !! Propagate values from one autodiff array to another
254 class(array_type), intent(in), target :: vertex_features
255 integer, dimension(:), intent(in) :: adj_ia
256 integer, dimension(:,:), intent(in) :: adj_ja
257 type(array_type), pointer :: c
258 end function kipf_propagate
259
260 module function kipf_update(a, weight, adj_ia) result(c)
261 !! Update the message passing layer
262 class(array_type), intent(in), target :: a
263 class(array_type), intent(in), target :: weight
264 integer, dimension(:), intent(in) :: adj_ia
265 type(array_type), pointer :: c
266 end function kipf_update
267 end interface
268
269 interface
270 module function duvenaud_propagate( &
271 vertex_features, edge_features, adj_ia, adj_ja &
272 ) result(c)
273 !! Duvenaud message passing function
274 class(array_type), intent(in), target :: vertex_features
275 class(array_type), intent(in), target :: edge_features
276 integer, dimension(:), intent(in) :: adj_ia
277 integer, dimension(:,:), intent(in) :: adj_ja
278 type(array_type), pointer :: c
279 end function duvenaud_propagate
280
281 module function duvenaud_update( &
282 a, weight, adj_ia, min_degree, max_degree &
283 ) result(c)
284 !! Duvenaud update function
285 class(array_type), intent(in), target :: a
286 class(array_type), intent(in), target :: weight
287 integer, dimension(:), intent(in) :: adj_ia
288 integer, intent(in) :: min_degree, max_degree
289 type(array_type), pointer :: c
290 end function duvenaud_update
291 end interface
292
293 interface
294 module function gno_kernel_eval( &
295 coords, kernel_params, adj_ia, adj_ja, &
296 coord_dim, kernel_hidden, F_in, F_out &
297 ) result(c)
298 !! Evaluate GNO kernel MLP on every edge
299 class(array_type), intent(in), target :: coords
300 class(array_type), intent(in), target :: kernel_params
301 integer, dimension(:), intent(in) :: adj_ia
302 integer, dimension(:,:), intent(in) :: adj_ja
303 integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out
304 type(array_type), pointer :: c
305 end function gno_kernel_eval
306
307 module function gno_aggregate( &
308 features, edge_kernels, adj_ia, adj_ja, F_in, F_out &
309 ) result(c)
310 !! Aggregate neighbour messages using per-edge kernels
311 class(array_type), intent(in), target :: features
312 class(array_type), intent(in), target :: edge_kernels
313 integer, dimension(:), intent(in) :: adj_ia
314 integer, dimension(:,:), intent(in) :: adj_ja
315 integer, intent(in) :: F_in, F_out
316 type(array_type), pointer :: c
317 end function gno_aggregate
318 end interface
319
320 interface
321 module function lno_encode( &
322 input, poles, num_inputs, num_modes &
323 ) result(c)
324 !! Encode input via Laplace basis: E(mu) @ u
325 class(array_type), intent(in), target :: input
326 class(array_type), intent(in), target :: poles
327 integer, intent(in) :: num_inputs, num_modes
328 type(array_type), pointer :: c
329 end function lno_encode
330
331 module function lno_decode( &
332 spectral, poles, num_outputs, num_modes &
333 ) result(c)
334 !! Decode via Laplace basis: D(mu) @ spectral
335 class(array_type), intent(in), target :: spectral
336 class(array_type), intent(in), target :: poles
337 integer, intent(in) :: num_outputs, num_modes
338 type(array_type), pointer :: c
339 end function lno_decode
340 end interface
341
342 interface
343 module function elem_scale(input, scale) result(c)
344 !! Element-wise multiply: out[i,s] = input[i,s] * scale[i,1]
345 !! Correctly handles non-sample-dependent scale vectors.
346 class(array_type), intent(in), target :: input
347 class(array_type), intent(in), target :: scale
348 type(array_type), pointer :: c
349 end function elem_scale
350 end interface
351
352 interface
353 module function ono_encode( &
354 input, basis_weights, num_inputs, num_basis &
355 ) result(c)
356 !! Encode via orthogonal basis: Q(B)^T @ u
357 class(array_type), intent(in), target :: input
358 class(array_type), intent(in), target :: basis_weights
359 integer, intent(in) :: num_inputs, num_basis
360 type(array_type), pointer :: c
361 end function ono_encode
362
363 module function ono_decode( &
364 mixed, basis_weights, num_inputs, num_basis &
365 ) result(c)
366 !! Decode via orthogonal basis: Q(B) @ mixed
367 class(array_type), intent(in), target :: mixed
368 class(array_type), intent(in), target :: basis_weights
369 integer, intent(in) :: num_inputs, num_basis
370 type(array_type), pointer :: c
371 end function ono_decode
372 end interface
373 !-------------------------------------------------------------------------------
374
375
39/61
✓ Branch 0 taken 36 times.
✓ Branch 1 taken 33 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 84 times.
✓ Branch 4 taken 33 times.
✓ Branch 5 taken 36 times.
✓ Branch 6 taken 48 times.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 84 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 36 times.
✓ Branch 12 taken 48 times.
✓ Branch 13 taken 36 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 36 times.
✓ Branch 18 taken 48 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 48 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 48 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 48 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 48 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 48 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 48 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 48 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 33 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 33 times.
✓ Branch 45 taken 33 times.
✓ Branch 46 taken 33 times.
✓ Branch 47 taken 48 times.
✓ Branch 48 taken 33 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 48 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 48 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 48 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 48 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 48 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 48 times.
✓ Branch 67 taken 33 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 21 times.
✓ Branch 70 taken 12 times.
✓ Branch 71 taken 33 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 21 times.
✓ Branch 74 taken 12 times.
396 end module athena__diffstruc_extd
376