GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
module athena__diffstruc_extd
  !! Module for extended differential structure types for Athena.
  !!
  !! This module declares derived types and the interfaces of separate module
  !! procedures (`module function`). It has no `contains` section, so every
  !! procedure body is provided by a submodule that is not part of this file;
  !! the declarations below must therefore stay in sync with those bodies.
  use coreutils, only: real32
  use diffstruc, only: array_type
  use athena__misc_types, only: facets_type
  implicit none


  private

  public :: array_ptr_type
  public :: add_layers, concat_layers
  public :: add_bias
  public :: piecewise, softmax, swish
  public :: huber
  public :: avgpool1d, avgpool2d, avgpool3d
  public :: maxpool1d, maxpool2d, maxpool3d
  public :: pad1d, pad2d, pad3d
  public :: merge_over_channels
  public :: batchnorm_array_type, batchnorm, batchnorm_inference
  public :: conv1d, conv2d, conv3d
  public :: kipf_propagate, kipf_update
  public :: duvenaud_propagate, duvenaud_update


  !-----------------------------------------------------------------------------
  ! Extended array types
  !-----------------------------------------------------------------------------
  type, extends(array_type) :: batchnorm_array_type
     !! `array_type` extended with the statistics carried by the results of
     !! `batchnorm` and `batchnorm_inference` (declared below).
     real(real32), dimension(:), allocatable :: mean
     !! Mean values (presumably one entry per channel — TODO confirm)
     real(real32), dimension(:), allocatable :: variance
     !! Variance values (presumably one entry per channel — TODO confirm)
     real(real32) :: epsilon
     !! Epsilon value stored alongside the statistics
  end type batchnorm_array_type


  !-----------------------------------------------------------------------------
  ! Array container types
  !-----------------------------------------------------------------------------
  type :: array_ptr_type
     !! Container holding a pointer to a rank-2 array of `array_type` objects.
     type(array_type), pointer :: array(:,:) => null()
     !! Default-initialised to null() so that `associated(x%array)` is well
     !! defined before the component is pointed at anything; without this the
     !! association status of a freshly declared instance is undefined.
  end type array_ptr_type


  !-----------------------------------------------------------------------------
  ! Operator interfaces
  !-----------------------------------------------------------------------------
  interface add_layers
     !! Generic name for `add_array_ptr`.
     module function add_array_ptr(a, idx1, idx2) result(c)
        !! Combine entries of a list of `array_ptr_type` containers.
        !! NOTE(review): presumably sums `a(:)%array(idx1, idx2)` — confirm
        !! against the submodule implementation.
        type(array_ptr_type), dimension(:), intent(in) :: a
        integer, intent(in) :: idx1, idx2
        type(array_type), pointer :: c
     end function add_array_ptr
  end interface add_layers

  interface concat_layers
     !! Generic name for `concat_array_ptr`.
     module function concat_array_ptr(a, idx1, idx2, dim) result(c)
        !! Concatenate entries of a list of `array_ptr_type` containers.
        !! NOTE(review): `dim` is presumably the concatenation axis — confirm.
        type(array_ptr_type), dimension(:), intent(in) :: a
        integer, intent(in) :: idx1, idx2, dim
        type(array_type), pointer :: c
     end function concat_array_ptr
  end interface concat_layers
  !-----------------------------------------------------------------------------


  !-----------------------------------------------------------------------------
  ! Activation functions and other operations
  !-----------------------------------------------------------------------------
  interface
     module function add_bias(input, bias, dim, dim_act_on_shape) result(output)
        !! Add `bias` to `input` along dimension `dim`.
        !! NOTE(review): the meaning of the optional `dim_act_on_shape` flag
        !! is defined by the implementation — TODO confirm and document.
        class(array_type), intent(in), target :: input
        class(array_type), intent(in), target :: bias
        integer, intent(in) :: dim
        logical, intent(in), optional :: dim_act_on_shape
        type(array_type), pointer :: output
     end function add_bias
  end interface

  interface piecewise
     !! Generic name for `piecewise_array`.
     module function piecewise_array(input, gradient, limit) result( output )
        !! Piecewise function of `input` parameterised by the scalars
        !! `gradient` and `limit` (exact form defined by the implementation).
        class(array_type), intent(in), target :: input
        real(real32), intent(in) :: gradient
        real(real32), intent(in) :: limit
        type(array_type), pointer :: output
     end function piecewise_array
  end interface piecewise

  interface softmax
     !! Generic name for `softmax_array`.
     module function softmax_array(input, dim) result(output)
        !! Softmax of `input`; `dim` presumably selects the normalisation
        !! axis — TODO confirm.
        class(array_type), intent(in), target :: input
        integer, intent(in) :: dim
        type(array_type), pointer :: output
     end function softmax_array
  end interface softmax

  interface swish
     !! Generic name for `swish_array`.
     module function swish_array(input, beta) result(output)
        !! Swish activation of `input` with scalar parameter `beta`.
        class(array_type), intent(in), target :: input
        real(real32), intent(in) :: beta
        type(array_type), pointer :: output
     end function swish_array
  end interface swish
  !-----------------------------------------------------------------------------


  !-----------------------------------------------------------------------------
  ! Loss functions
  !-----------------------------------------------------------------------------
  interface huber
     !! Generic name for `huber_array`.
     module function huber_array(delta, gamma) result( output )
        !! Huber loss evaluated on `delta` with scalar parameter `gamma`.
        !! NOTE(review): presumably `delta` is the prediction/target
        !! difference and `gamma` the switch-over threshold — confirm.
        class(array_type), intent(in), target :: delta
        real(real32), intent(in) :: gamma
        type(array_type), pointer :: output
     end function huber_array
  end interface huber
  !-----------------------------------------------------------------------------


  !-----------------------------------------------------------------------------
  ! Layer operations
  !-----------------------------------------------------------------------------
  ! Average pooling over 1, 2 or 3 spatial dimensions; `pool_size` and
  ! `stride` carry one entry per pooled dimension.
  interface
     module function avgpool1d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, intent(in) :: pool_size
        integer, intent(in) :: stride
        type(array_type), pointer :: output
     end function avgpool1d

     module function avgpool2d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, dimension(2), intent(in) :: pool_size
        integer, dimension(2), intent(in) :: stride
        type(array_type), pointer :: output
     end function avgpool2d

     module function avgpool3d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, dimension(3), intent(in) :: pool_size
        integer, dimension(3), intent(in) :: stride
        type(array_type), pointer :: output
     end function avgpool3d
  end interface

  ! Max pooling over 1, 2 or 3 spatial dimensions; `pool_size` and `stride`
  ! carry one entry per pooled dimension.
  interface
     module function maxpool1d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, intent(in) :: pool_size
        integer, intent(in) :: stride
        type(array_type), pointer :: output
     end function maxpool1d

     module function maxpool2d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, dimension(2), intent(in) :: pool_size
        integer, dimension(2), intent(in) :: stride
        type(array_type), pointer :: output
     end function maxpool2d

     module function maxpool3d(input, pool_size, stride) result(output)
        type(array_type), intent(in), target :: input
        integer, dimension(3), intent(in) :: pool_size
        integer, dimension(3), intent(in) :: stride
        type(array_type), pointer :: output
     end function maxpool3d
  end interface

  ! Padding over 1, 2 or 3 spatial dimensions. `facets` and `pad_size` carry
  ! one entry per padded dimension; `imethod` is an integer code selecting the
  ! padding method (valid codes are defined by the implementation — confirm).
  interface
     module function pad1d(input, facets, pad_size, imethod) result(output)
        type(array_type), intent(in), target :: input
        type(facets_type), intent(in) :: facets
        integer, intent(in) :: pad_size
        integer, intent(in) :: imethod
        type(array_type), pointer :: output
     end function pad1d

     module function pad2d(input, facets, pad_size, imethod) result(output)
        type(array_type), intent(in), target :: input
        type(facets_type), dimension(2), intent(in) :: facets
        integer, dimension(2), intent(in) :: pad_size
        integer, intent(in) :: imethod
        type(array_type), pointer :: output
     end function pad2d

     module function pad3d(input, facets, pad_size, imethod) result(output)
        type(array_type), intent(in), target :: input
        type(facets_type), dimension(3), intent(in) :: facets
        integer, dimension(3), intent(in) :: pad_size
        integer, intent(in) :: imethod
        type(array_type), pointer :: output
     end function pad3d
  end interface

  interface merge_over_channels
     !! Generic name; the only specific takes a scalar `fsource`.
     module function merge_scalar_over_channels(tsource, fsource, mask) &
          result(output)
        !! Select between `tsource` and the scalar `fsource` under the rank-2
        !! logical `mask`.
        !! NOTE(review): presumably analogous to the intrinsic `merge` applied
        !! per channel — confirm against the implementation.
        class(array_type), intent(in), target :: tsource
        real(real32), intent(in) :: fsource
        logical, dimension(:,:), intent(in) :: mask
        type(array_type), pointer :: output
     end function merge_scalar_over_channels
  end interface merge_over_channels

  interface
     module function batchnorm( &
          input, params, momentum, mean, variance, epsilon &
     ) result( output )
        !! Batch normalisation taking a `momentum` argument; returns a
        !! `batchnorm_array_type`, which carries `mean`, `variance` and
        !! `epsilon` components.
        class(array_type), intent(in), target :: input
        class(array_type), intent(in), target :: params
        real(real32), intent(in) :: momentum
        real(real32), dimension(:), intent(in) :: mean
        real(real32), dimension(:), intent(in) :: variance
        real(real32), intent(in) :: epsilon
        type(batchnorm_array_type), pointer :: output
     end function batchnorm

     module function batchnorm_inference( &
          input, params, mean, variance, epsilon &
     ) result( output )
        !! Batch normalisation without a `momentum` argument, using the
        !! supplied `mean` and `variance` (presumably fixed statistics at
        !! inference time — confirm).
        class(array_type), intent(in), target :: input
        class(array_type), intent(in), target :: params
        real(real32), dimension(:), intent(in) :: mean
        real(real32), dimension(:), intent(in) :: variance
        real(real32), intent(in) :: epsilon
        type(batchnorm_array_type), pointer :: output
     end function batchnorm_inference
  end interface

  ! Convolution over 1, 2 or 3 spatial dimensions; `stride` and `dilation`
  ! carry one entry per convolved dimension.
  interface
     module function conv1d(input, kernel, stride, dilation) result(output)
        type(array_type), intent(in), target :: input
        type(array_type), intent(in), target :: kernel
        integer, intent(in) :: stride
        integer, intent(in) :: dilation
        type(array_type), pointer :: output
     end function conv1d

     module function conv2d(input, kernel, stride, dilation) result(output)
        type(array_type), intent(in), target :: input
        type(array_type), intent(in), target :: kernel
        integer, dimension(2), intent(in) :: stride
        integer, dimension(2), intent(in) :: dilation
        type(array_type), pointer :: output
     end function conv2d

     module function conv3d(input, kernel, stride, dilation) result(output)
        type(array_type), intent(in), target :: input
        type(array_type), intent(in), target :: kernel
        integer, dimension(3), intent(in) :: stride
        integer, dimension(3), intent(in) :: dilation
        type(array_type), pointer :: output
     end function conv3d
  end interface

  ! Graph message-passing operations.
  ! NOTE(review): `adj_ia` (rank 1) and `adj_ja` (rank 2) look like a
  ! compressed-row style adjacency description — TODO confirm the layout.
  interface
     module function kipf_propagate(vertex_features, adj_ia, adj_ja) result(c)
        !! Propagate values from one autodiff array to another
        class(array_type), intent(in), target :: vertex_features
        integer, dimension(:), intent(in) :: adj_ia
        integer, dimension(:,:), intent(in) :: adj_ja
        type(array_type), pointer :: c
     end function kipf_propagate

     module function kipf_update(a, weight, adj_ia) result(c)
        !! Update the message passing layer
        class(array_type), intent(in), target :: a
        class(array_type), intent(in), target :: weight
        integer, dimension(:), intent(in) :: adj_ia
        type(array_type), pointer :: c
     end function kipf_update
  end interface

  interface
     module function duvenaud_propagate( &
          vertex_features, edge_features, adj_ia, adj_ja &
     ) result(c)
        !! Duvenaud message passing function
        class(array_type), intent(in), target :: vertex_features
        class(array_type), intent(in), target :: edge_features
        integer, dimension(:), intent(in) :: adj_ia
        integer, dimension(:,:), intent(in) :: adj_ja
        type(array_type), pointer :: c
     end function duvenaud_propagate

     module function duvenaud_update( &
          a, weight, adj_ia, min_degree, max_degree &
     ) result(c)
        !! Duvenaud update function; `min_degree` and `max_degree` are
        !! integer bounds (presumably on vertex degree — confirm).
        class(array_type), intent(in), target :: a
        class(array_type), intent(in), target :: weight
        integer, dimension(:), intent(in) :: adj_ia
        integer, intent(in) :: min_degree, max_degree
        type(array_type), pointer :: c
     end function duvenaud_update
  end interface
  !-----------------------------------------------------------------------------

end module athena__diffstruc_extd
292