GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_duvenaud_msgpass_layer.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 328 396 82.8%
Functions: 0 0 -%
Branches: 1216 2633 46.2%

Line Branch Exec Source
1 module athena__duvenaud_msgpass_layer
2 !! Module implementing Duvenaud message passing for molecular graphs
3 !!
4 !! This module implements the graph neural network architecture from
5 !! Duvenaud et al. (2015) for learning on molecular graphs with both
6 !! vertex (node) and edge features.
7 !!
8 !! Mathematical operation (per time step t):
9 !! \[ h_v^{(t+1)} = \sigma\left( h_v^{(t)} + \sum_{u \in \mathcal{N}(v)} M(h_v^{(t)}, h_u^{(t)}, e_{vu}) \right) \]
10 !!
11 !! Graph readout (aggregation to fixed-size vector):
12 !! \[ h_{\text{graph}} = \sigma_{\text{readout}}\left( \sum_{d=1}^D \sum_{v:\deg(v)=d} W_d h_v^{(T)} \right) \]
13 !!
14 !! where \( M \) is a learned message function, \( \sigma \) is activation function,
15 !! \( \mathcal{N}(v) \) are neighbors of \( v \), \( e_{vu} \) are edge features, \( W_d \) are
16 !! degree-specific weight matrices, and \( D \) is max vertex degree.
17 !!
18 !! Reference: Duvenaud et al. (2015), NeurIPS
19 use coreutils, only: real32
20 use graphstruc, only: graph_type
21 use athena__misc_types, only: base_actv_type, base_init_type, onnx_attribute_type, &
22 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
23 use diffstruc, only: array_type, sum, matmul, operator(+)
24 use athena__base_layer, only: base_layer_type
25 use athena__msgpass_layer, only: msgpass_layer_type
26 use athena__diffstruc_extd, only: duvenaud_propagate, duvenaud_update
27 implicit none
28
29
30 private
31
32 public :: duvenaud_msgpass_layer_type
33 public :: read_duvenaud_msgpass_layer
34
35
36 !-------------------------------------------------------------------------------
37 ! Message passing layer
38 !-------------------------------------------------------------------------------
type, extends(msgpass_layer_type) :: duvenaud_msgpass_layer_type
   !! Duvenaud et al. (2015) message passing layer for graphs carrying both
   !! vertex and edge features

   integer :: min_vertex_degree = 1, max_vertex_degree = 0
   !! Inclusive vertex-degree range; (max - min + 1) degree buckets enter the
   !! message parameter count

   class(base_actv_type), allocatable :: activation_readout
   !! Activation used by the readout stage (the message activation is held in
   !! the inherited activation component)

   type(array_type), allocatable :: z(:,:), z_readout(:,:)
   !! Work arrays for the message and readout stages
   !! NOTE(review): purpose inferred from names only -- confirm against
   !! update_message_duvenaud / update_readout_duvenaud

 contains
   procedure, pass(this) :: get_num_params => get_num_params_duvenaud
   !! Number of learnable parameters
   procedure, pass(this) :: get_attributes => get_attributes_duvenaud
   !! Hyperparameters exported as ONNX attributes
   procedure, pass(this) :: set_hyperparams => set_hyperparams_duvenaud
   !! Store hyperparameters, activations and kernel initialiser
   procedure, pass(this) :: init => init_duvenaud
   !! Initialise the layer
   procedure, pass(this) :: print_to_unit => print_to_unit_duvenaud
   !! Print the layer
   procedure, pass(this) :: read => read_duvenaud
   !! Read the layer

   procedure, pass(this) :: set_graph => set_graph_duvenaud
   !! Attach the input graph(s)

   procedure, pass(this) :: update_message => update_message_duvenaud
   !! Message passing step

   procedure, pass(this) :: update_readout => update_readout_duvenaud
   !! Readout step

   procedure, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_duvenaud
   !! Emit ONNX JSON nodes for this layer
   procedure, pass(this) :: &
        emit_onnx_graph_inputs => emit_onnx_graph_inputs_duvenaud
   !! Emit graph input tensor declarations for this layer

   final :: finalise_duvenaud
   !! Release shape and output storage
end type duvenaud_msgpass_layer_type
83
! Constructor-style generic: duvenaud_msgpass_layer_type(...) -> layer_setup
!-----------------------------------------------------------------------------
interface duvenaud_msgpass_layer_type
   !! Build a Duvenaud message passing layer from its hyperparameters
   module function layer_setup( &
        num_vertex_features, num_edge_features, num_time_steps, &
        max_vertex_degree, &
        num_outputs, &
        min_vertex_degree, &
        message_activation, &
        readout_activation, &
        kernel_initialiser, &
        verbose &
   ) result(layer)
     !! Set up the message passing layer
     integer, intent(in) :: num_vertex_features(:)
     !! Vertex feature sizes (one value, or one per step 0..num_time_steps)
     integer, intent(in) :: num_edge_features(:)
     !! Edge feature sizes (one value, or one per step 0..num_time_steps)
     integer, intent(in) :: num_time_steps
     !! Number of message passing steps
     integer, intent(in) :: max_vertex_degree
     !! Upper end of the vertex-degree range
     integer, intent(in) :: num_outputs
     !! Readout output dimension
     integer, intent(in), optional :: min_vertex_degree
     !! Lower end of the vertex-degree range
     class(*), intent(in), optional :: message_activation
     !! Message activation (name or activation object)
     class(*), intent(in), optional :: readout_activation
     !! Readout activation (name or activation object)
     character(len=*), intent(in), optional :: kernel_initialiser
     !! Kernel initialiser name
     integer, intent(in), optional :: verbose
     !! Verbosity level
     type(duvenaud_msgpass_layer_type) :: layer
     !! Constructed layer
   end function layer_setup
end interface duvenaud_msgpass_layer_type

character(len=*), parameter :: default_message_actv_name = "sigmoid"
!! Message activation used when none is supplied
character(len=*), parameter :: default_readout_actv_name = "softmax"
!! Readout activation used when none is supplied
125
126
127
128 contains
129
130 !###############################################################################
function get_attributes_duvenaud(this) result(attributes)
  !! Get the attributes of the Duvenaud message passing layer (for ONNX export)
  !!
  !! Exports the hyperparameters needed to reconstruct the layer architecture:
  !! - num_time_steps: number of message passing iterations
  !! - min_vertex_degree, max_vertex_degree: degree bucket range
  !! - num_vertex_features: vertex feature sizes for steps 0..num_time_steps,
  !!   space-separated
  !! - num_edge_features: edge feature sizes for steps 0..num_time_steps,
  !!   space-separated
  !! - num_outputs: readout output dimension
  !! - message_activation, readout_activation: activation function names
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(in) :: this
  !! Instance of the layer
  type(onnx_attribute_type), allocatable, dimension(:) :: attributes
  !! Attributes of the layer


  allocate(attributes(8))

  attributes(1) = onnx_attribute_type( &
       name='num_time_steps', type='int', &
       val=int_to_str(this%num_time_steps))

  attributes(2) = onnx_attribute_type( &
       name='min_vertex_degree', type='int', &
       val=int_to_str(this%min_vertex_degree))

  attributes(3) = onnx_attribute_type( &
       name='max_vertex_degree', type='int', &
       val=int_to_str(this%max_vertex_degree))

  attributes(4) = onnx_attribute_type( &
       name='num_vertex_features', type='ints', &
       val=join_ints(this%num_vertex_features))

  attributes(5) = onnx_attribute_type( &
       name='num_edge_features', type='ints', &
       val=join_ints(this%num_edge_features))

  attributes(6) = onnx_attribute_type( &
       name='num_outputs', type='int', &
       val=int_to_str(this%num_outputs))

  attributes(7) = onnx_attribute_type( &
       name='message_activation', type='string', &
       val=trim(this%activation%name))

  ! Previously documented but never exported
  attributes(8) = onnx_attribute_type( &
       name='readout_activation', type='string', &
       val=trim(this%activation_readout%name))

contains

  function int_to_str(value) result(str)
    !! Minimal-width decimal representation of an integer
    integer, intent(in) :: value
    character(len=:), allocatable :: str
    character(len=32) :: digits

    write(digits, '(I0)') value
    str = trim(digits)
  end function int_to_str

  function join_ints(values) result(str)
    !! Join integers into a single space-separated string
    !!
    !! The string is grown with deferred-length concatenation. This avoids
    !! writing to an internal file that also appears in its own output list
    !! (not permitted by the standard), and avoids any fixed length limit.
    integer, dimension(:), intent(in) :: values
    character(len=:), allocatable :: str
    integer :: i

    str = ''
    do i = 1, size(values)
       if(i .gt. 1) str = str // ' '
       str = str // int_to_str(values(i))
    end do
  end function join_ints

end function get_attributes_duvenaud
200 !###############################################################################
201
202
203 !###############################################################################
subroutine finalise_duvenaud(this)
  !! Final procedure: release the input/output shape arrays and the output
  !! storage if they are allocated
  implicit none

  ! Arguments
  type(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Layer being finalised


  if(allocated(this%input_shape))then
     deallocate(this%input_shape)
  end if

  if(allocated(this%output_shape))then
     deallocate(this%output_shape)
  end if

  if(allocated(this%output))then
     deallocate(this%output)
  end if

end subroutine finalise_duvenaud
217 !###############################################################################
218
219
220 !##############################################################################!
221 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
222 !##############################################################################!
223
224
225 !###############################################################################
pure function get_num_params_duvenaud(this) result(num_params)
  !! Get the number of parameters for the message passing layer
  !!
  !! For every step t = 1..num_time_steps this counts:
  !! - message weights: (nv(t-1) + ne(0)) * nv(t) per degree bucket, with
  !!   (max_vertex_degree - min_vertex_degree + 1) buckets (the same formula
  !!   set_hyperparams_duvenaud uses for num_params_msg(t));
  !! - readout weights: nv(t) * num_outputs.
  !! Here nv and ne are the per-step vertex and edge feature sizes.
  !!
  !! When feature sizes are the same at every step, this equals the previous
  !! formula, which used only nv(0). Per-step sizes are now honoured.
  !! NOTE(review): set_hyperparams_duvenaud sums readout terms over
  !! t = 0..num_time_steps, while this counts t = 1..num_time_steps.
  !! Confirm which matches the readout weights allocated in init.
  !! This procedure is based on code from the neural-fortran library
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(in) :: this
  !! Instance of the message passing layer
  integer :: num_params
  !! Number of parameters

  ! Local variables
  integer :: t
  !! Time-step index
  integer :: num_degrees
  !! Number of vertex-degree buckets


  num_params = 0
  num_degrees = this%max_vertex_degree - this%min_vertex_degree + 1
  do t = 1, this%num_time_steps
     num_params = num_params + &
          ( this%num_vertex_features(t-1) + this%num_edge_features(0) ) * &
          this%num_vertex_features(t) * num_degrees + &
          this%num_vertex_features(t) * this%num_outputs
  end do

end function get_num_params_duvenaud
247 !###############################################################################
248
249
250 !##############################################################################!
251 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
252 !##############################################################################!
253
254
255 !###############################################################################
module function layer_setup( &
     num_vertex_features ,num_edge_features, num_time_steps, &
     max_vertex_degree, &
     num_outputs, &
     min_vertex_degree, &
     message_activation, &
     readout_activation, &
     kernel_initialiser, &
     verbose &
) result(layer)
  !! Set up the message passing layer
  !!
  !! Resolves the optional arguments to concrete values:
  !! - verbose defaults to 0 and min_vertex_degree to 1;
  !! - activations default to default_message_actv_name and
  !!   default_readout_actv_name;
  !! - the kernel initialiser is left unallocated when absent, so that
  !!   set_hyperparams_duvenaud picks a default.
  !! It then calls set_hyperparams and init. If
  !! max_vertex_degree < min_vertex_degree, an error is reported on stderr and
  !! the layer is returned without hyperparameters set.
  use, intrinsic :: iso_fortran_env, only: error_unit
  use athena__initialiser, only: initialiser_setup
  use athena__activation, only: activation_setup
  implicit none

  ! Arguments
  integer, dimension(:), intent(in) :: num_vertex_features
  !! Number of vertex features
  integer, dimension(:), intent(in) :: num_edge_features
  !! Number of edge features
  integer, intent(in) :: num_time_steps
  !! Number of time steps
  integer, intent(in) :: max_vertex_degree
  !! Maximum vertex degree
  integer, intent(in) :: num_outputs
  !! Number of outputs
  integer, optional, intent(in) :: min_vertex_degree
  !! Minimum vertex degree
  class(*), optional, intent(in) :: message_activation, &
       readout_activation
  !! Message and readout activation functions
  character(*), optional, intent(in) :: kernel_initialiser
  !! Kernel initialiser
  integer, optional, intent(in) :: verbose
  !! Verbosity level
  type(duvenaud_msgpass_layer_type) :: layer
  !! Instance of the message passing layer

  ! Local variables
  ! No initialisers in these declarations: an initialiser on a local variable
  ! implies SAVE, so a value passed on one call would leak into later calls
  ! that omit the argument.
  integer :: verbose_
  !! Verbosity level
  class(base_actv_type), allocatable :: message_activation_, readout_activation_
  !! Activation functions
  class(base_init_type), allocatable :: kernel_initialiser_
  !! Kernel initialiser
  integer :: min_vertex_degree_
  !! Minimum vertex degree


  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set activation functions
  !---------------------------------------------------------------------------
  if(present(message_activation))then
     message_activation_ = activation_setup(message_activation)
  else
     message_activation_ = activation_setup(default_message_actv_name)
  end if

  if(present(readout_activation))then
     readout_activation_ = activation_setup(readout_activation)
  else
     readout_activation_ = activation_setup(default_readout_actv_name)
  end if


  !---------------------------------------------------------------------------
  ! Set minimum vertex degree
  !---------------------------------------------------------------------------
  min_vertex_degree_ = 1
  if(present(min_vertex_degree)) min_vertex_degree_ = min_vertex_degree
  if(max_vertex_degree.lt.min_vertex_degree_)then
     write(error_unit,*) "Error: max_vertex_degree < min_vertex_degree"
     return
  end if


  !---------------------------------------------------------------------------
  ! Define weights (kernels) initialiser
  !---------------------------------------------------------------------------
  ! Left unallocated when absent; set_hyperparams then chooses a default
  if(present(kernel_initialiser))then
     kernel_initialiser_ = initialiser_setup(kernel_initialiser)
  end if


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  call layer%set_hyperparams( &
       num_vertex_features = num_vertex_features, &
       num_edge_features = num_edge_features, &
       min_vertex_degree = min_vertex_degree_, &
       max_vertex_degree = max_vertex_degree, &
       num_time_steps = num_time_steps, &
       num_outputs = num_outputs, &
       message_activation = message_activation_, &
       readout_activation = readout_activation_, &
       kernel_initialiser = kernel_initialiser_, &
       verbose = verbose_ &
  )


  !---------------------------------------------------------------------------
  ! Initialise layer shape
  !---------------------------------------------------------------------------
  call layer%init(input_shape=[ &
       layer%num_vertex_features(0), &
       layer%num_edge_features(0) &
  ])

end function layer_setup
366 !###############################################################################
367
368
369 !###############################################################################
subroutine set_hyperparams_duvenaud( &
     this, &
     num_vertex_features, num_edge_features, &
     min_vertex_degree, &
     max_vertex_degree, &
     num_time_steps, &
     num_outputs, &
     message_activation, &
     readout_activation, &
     kernel_initialiser, &
     verbose &
)
  !! Set the hyperparameters for the message passing layer
  !!
  !! - Stores the scalar hyperparameters.
  !! - Expands the vertex and edge feature sizes onto arrays indexed
  !!   0:num_time_steps. A length-1 input is broadcast to every step.
  !! - Installs the message and readout activations and the kernel
  !!   initialiser, using defaults when an argument is not allocated.
  !! - Precomputes num_params_msg(1:num_time_steps) and num_params_readout.
  !! - Deallocates input_shape and output_shape.
  !!
  !! If a feature-size vector has neither length 1 nor length
  !! num_time_steps + 1, an error is written to stderr and the program
  !! terminates with ERROR STOP (non-zero exit status).
  use, intrinsic :: iso_fortran_env, only: error_unit
  use athena__activation, only: activation_setup
  use athena__initialiser, only: get_default_initialiser, initialiser_setup
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Instance of the message passing layer
  integer, dimension(:), intent(in) :: num_vertex_features
  !! Number of vertex features
  integer, dimension(:), intent(in) :: num_edge_features
  !! Number of edge features
  integer, intent(in) :: min_vertex_degree
  !! Minimum vertex degree
  integer, intent(in) :: max_vertex_degree
  !! Maximum vertex degree
  integer, intent(in) :: num_time_steps
  !! Number of time steps
  integer, intent(in) :: num_outputs
  !! Number of outputs
  class(base_actv_type), allocatable, intent(in) :: &
       message_activation, &
       readout_activation
  !! Message and readout activation functions
  class(base_init_type), allocatable, intent(in) :: kernel_initialiser
  !! Kernel initialiser
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: t
  !! Loop index
  character(len=256) :: buffer
  !! Name of the default kernel initialiser


  this%name = 'duvenaud'
  this%type = 'msgp'
  this%input_rank = 2
  this%output_rank = 1
  this%min_vertex_degree = min_vertex_degree
  this%max_vertex_degree = max_vertex_degree
  this%num_time_steps = num_time_steps
  this%num_outputs = num_outputs


  !---------------------------------------------------------------------------
  ! Expand feature sizes onto 0:num_time_steps
  !---------------------------------------------------------------------------
  if(allocated(this%num_vertex_features)) &
       deallocate(this%num_vertex_features)
  if(allocated(this%num_edge_features)) &
       deallocate(this%num_edge_features)

  if(size(num_vertex_features, 1) .eq. 1)then
     ! A single value is broadcast to every time step
     allocate( &
          this%num_vertex_features(0:num_time_steps), &
          source = num_vertex_features(1) &
     )
  elseif(size(num_vertex_features, 1) .eq. num_time_steps + 1)then
     allocate( &
          this%num_vertex_features(0:this%num_time_steps), &
          source = num_vertex_features &
     )
  else
     ! Report on stderr and stop with a failure status (plain STOP exits 0)
     write(error_unit,*) "Error: num_vertex_features must be a scalar or a vector of &
          &length num_time_steps + 1"
     error stop
  end if

  if(size(num_edge_features, 1) .eq. 1)then
     ! A single value is broadcast to every time step
     allocate( &
          this%num_edge_features(0:num_time_steps), &
          source = num_edge_features(1) &
     )
  elseif(size(num_edge_features, 1) .eq. num_time_steps + 1)then
     allocate( &
          this%num_edge_features(0:this%num_time_steps), &
          source = num_edge_features &
     )
  else
     write(error_unit,*) "Error: num_edge_features must be a scalar or a vector of &
          &length num_time_steps + 1"
     error stop
  end if
  this%use_graph_input = .true.
  this%use_graph_output = .false.


  !---------------------------------------------------------------------------
  ! Activations (defaults used when not supplied)
  !---------------------------------------------------------------------------
  if(allocated(this%activation)) deallocate(this%activation)
  if(allocated(this%activation_readout)) deallocate(this%activation_readout)
  if(.not.allocated(message_activation))then
     this%activation = activation_setup(default_message_actv_name)
  else
     allocate( this%activation, source=message_activation )
  end if

  if(.not.allocated(readout_activation))then
     this%activation_readout = activation_setup(default_readout_actv_name)
  else
     allocate(this%activation_readout, source=readout_activation)
  end if


  !---------------------------------------------------------------------------
  ! Kernel initialiser (default chosen from the message activation)
  !---------------------------------------------------------------------------
  if(allocated(this%kernel_init)) deallocate(this%kernel_init)
  if(.not.allocated(kernel_initialiser))then
     buffer = get_default_initialiser(this%activation%name)
     this%kernel_init = initialiser_setup(buffer)
  else
     allocate(this%kernel_init, source=kernel_initialiser)
  end if

  if(present(verbose))then
     if(abs(verbose).gt.0)then
        write(*,'("DUVENAUD message activation function: ",A)') &
             trim(this%activation%name)
        write(*,'("DUVENAUD readout activation function: ",A)') &
             trim(this%activation_readout%name)
        write(*,'("DUVENAUD kernel initialiser: ",A)') &
             trim(this%kernel_init%name)
     end if
  end if


  !---------------------------------------------------------------------------
  ! Parameter counts
  !---------------------------------------------------------------------------
  if(allocated(this%num_params_msg)) deallocate(this%num_params_msg)
  allocate(this%num_params_msg(1:this%num_time_steps))
  do t = 1, this%num_time_steps
     ! NOTE(review): the edge feature size at t = 0 is used for every step;
     ! presumably edge features are not updated between steps -- confirm
     this%num_params_msg(t) = &
          ( this%num_vertex_features(t-1) + this%num_edge_features(0) ) * &
          this%num_vertex_features(t) * &
          ( this%max_vertex_degree - this%min_vertex_degree + 1 )
  end do
  ! NOTE(review): this sums readout terms over t = 0..num_time_steps
  ! (num_time_steps + 1 terms), whereas get_num_params_duvenaud counts
  ! num_time_steps readout terms. Confirm which matches the readout weights.
  this%num_params_readout = &
       sum( this%num_vertex_features * this%num_outputs )


  if(allocated(this%input_shape)) deallocate(this%input_shape)
  if(allocated(this%output_shape)) deallocate(this%output_shape)

end subroutine set_hyperparams_duvenaud
506 !###############################################################################
507
508
509 !###############################################################################
510
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
6 subroutine init_duvenaud(this, input_shape, verbose)
511 !! Initialise the message passing layer
512 use athena__initialiser, only: initialiser_setup
513 implicit none
514
515 ! Arguments
516 class(duvenaud_msgpass_layer_type), intent(inout) :: this
517 !! Instance of the fully connected layer
518 integer, dimension(:), intent(in) :: input_shape
519 !! Input shape
520 integer, optional, intent(in) :: verbose
521 !! Verbosity level
522
523 ! Local variables
524 integer :: t
525 !! Loop index
526 integer :: verbose_ = 0
527 !! Verbosity level
528
529
530 !---------------------------------------------------------------------------
531 ! Initialise optional arguments
532 !---------------------------------------------------------------------------
533
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
6 if(present(verbose)) verbose_ = verbose
534
535
536 !---------------------------------------------------------------------------
537 ! Initialise number of inputs
538 !---------------------------------------------------------------------------
539
10/16
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✓ Branch 14 taken 12 times.
✓ Branch 15 taken 6 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✓ Branch 18 taken 12 times.
✓ Branch 19 taken 6 times.
30 if(.not.allocated(this%input_shape)) call this%set_shape([input_shape])
540
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✓ Branch 7 taken 6 times.
12 this%output_shape = [this%num_outputs]
541 6 this%num_params = this%get_num_params()
542
543
544 !---------------------------------------------------------------------------
545 ! Allocate weight, weight steps (velocities), output, and activation
546 !---------------------------------------------------------------------------
547
8/16
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
6 allocate(this%weight_shape(3,2*this%num_time_steps))
548
21/40
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 6 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 6 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 6 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 6 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 6 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 6 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 6 times.
✓ Branch 31 taken 22 times.
✓ Branch 32 taken 6 times.
✓ Branch 33 taken 22 times.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 22 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 22 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 22 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 22 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 22 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 22 times.
28 allocate(this%params(this%num_time_steps*2))
549
2/2
✓ Branch 0 taken 11 times.
✓ Branch 1 taken 6 times.
17 do t = 1, this%num_time_steps
550 99 this%weight_shape(:,t) = [ &
551 22 this%num_vertex_features(t), &
552 22 this%num_vertex_features(t-1) + this%num_edge_features(0), &
553 this%max_vertex_degree - this%min_vertex_degree + 1 &
554
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 11 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 11 times.
✓ Branch 39 taken 33 times.
✓ Branch 40 taken 11 times.
44 ]
555 77 this%weight_shape(:,t+this%num_time_steps) = &
556
11/20
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✓ Branch 27 taken 33 times.
✓ Branch 28 taken 11 times.
44 [ this%num_outputs, this%num_vertex_features(t), 1 ]
557
14/24
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✓ Branch 27 taken 33 times.
✓ Branch 28 taken 11 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11 times.
✓ Branch 31 taken 44 times.
✓ Branch 32 taken 11 times.
88 call this%params(t)%allocate( [ this%weight_shape(:,t), 1 ] )
558 22 call this%params(t+this%num_time_steps)%allocate( &
559 66 [ this%weight_shape(:2,t+this%num_time_steps), 1 ] &
560
14/24
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✓ Branch 27 taken 22 times.
✓ Branch 28 taken 11 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11 times.
✓ Branch 31 taken 33 times.
✓ Branch 32 taken 11 times.
66 )
561
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 call this%params(t)%set_requires_grad(.true.)
562
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 this%params(t)%fix_pointer = .true.
563
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 this%params(t)%is_temporary = .false.
564
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 this%params(t)%is_sample_dependent = .false.
565
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 22 times.
✓ Branch 13 taken 11 times.
33 this%params(t)%indices = [ this%min_vertex_degree, this%max_vertex_degree ]
566
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 call this%params(t+this%num_time_steps)%set_requires_grad(.true.)
567
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 this%params(t+this%num_time_steps)%fix_pointer = .true.
568
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
11 this%params(t+this%num_time_steps)%is_temporary = .false.
569
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
17 this%params(t+this%num_time_steps)%is_sample_dependent = .false.
570 end do
571
572
573 !---------------------------------------------------------------------------
574 ! Initialise weights (kernels)
575 !---------------------------------------------------------------------------
576
2/2
✓ Branch 0 taken 11 times.
✓ Branch 1 taken 6 times.
17 do t = 1, this%num_time_steps, 1
577 call this%kernel_init%initialise( &
578 110 this%params(t)%val(:,1), &
579 44 fan_in = this%num_vertex_features(t-1) + this%num_edge_features(0), &
580 22 fan_out = this%num_vertex_features(t), &
581 22 spacing = [ this%num_vertex_features(t-1) ] &
582
20/38
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 11 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 11 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 11 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 11 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 11 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 11 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 11 times.
✓ Branch 54 taken 11 times.
✓ Branch 55 taken 11 times.
22 )
583 call this%kernel_init%initialise( &
584 110 this%params(t+this%num_time_steps)%val(:,1), &
585 44 fan_in = sum(this%num_vertex_features), &
586 fan_out = this%num_outputs, &
587 spacing = this%num_vertex_features &
588
16/30
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 11 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 11 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 11 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 11 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 11 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 11 times.
✓ Branch 42 taken 34 times.
✓ Branch 43 taken 11 times.
51 )
589 end do
590
591
592 !---------------------------------------------------------------------------
593 ! Allocate arrays
594 !---------------------------------------------------------------------------
595
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
6 if(allocated(this%output)) deallocate(this%output)
596
15/26
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✓ Branch 33 taken 6 times.
✓ Branch 34 taken 6 times.
✓ Branch 35 taken 6 times.
✓ Branch 36 taken 6 times.
18 allocate(this%output(1,1))
597
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
6 if(allocated(this%z)) deallocate(this%z)
598
599 6 end subroutine init_duvenaud
600 !###############################################################################
601
602
603 !##############################################################################!
604
1/2
✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
10 subroutine set_graph_duvenaud(this, graph)
605 !! Set the graph structure of the input data
606 implicit none
607
608 ! Arguments
609 class(duvenaud_msgpass_layer_type), intent(inout) :: this
610 !! Instance of the layer
611 type(graph_type), dimension(:), intent(in) :: graph
612 !! Graph structure of input data
613
614 ! Local variables
615 integer :: s, t
616 !! Loop indices
617
618
2/2
✓ Branch 0 taken 5 times.
✓ Branch 1 taken 5 times.
10 if(allocated(this%graph))then
619
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
5 if(size(this%graph).ne.size(graph))then
620 deallocate(this%graph)
621 allocate(this%graph(size(graph)))
622 end if
623 else
624
27/60
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✓ Branch 9 taken 5 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 5 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 5 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 5 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 5 times.
✓ Branch 38 taken 5 times.
✓ Branch 39 taken 5 times.
✓ Branch 40 taken 5 times.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 5 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 5 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 5 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 5 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 5 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 5 times.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✓ Branch 65 taken 5 times.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✗ Branch 68 not taken.
✗ Branch 69 not taken.
10 allocate(this%graph(size(graph)))
625 end if
626
5/8
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✓ Branch 9 taken 10 times.
✓ Branch 10 taken 10 times.
20 do s = 1, size(graph)
627
16/28
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✓ Branch 24 taken 5 times.
✓ Branch 25 taken 5 times.
✓ Branch 26 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 58 times.
✓ Branch 35 taken 10 times.
68 this%graph(s)%adj_ia = graph(s)%adj_ia
628
25/44
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 10 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 10 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 10 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 10 times.
✓ Branch 36 taken 5 times.
✓ Branch 37 taken 5 times.
✓ Branch 38 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 5 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 5 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 118 times.
✓ Branch 53 taken 10 times.
✓ Branch 54 taken 236 times.
✓ Branch 55 taken 118 times.
364 this%graph(s)%adj_ja = graph(s)%adj_ja
629
16/28
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✓ Branch 24 taken 5 times.
✓ Branch 25 taken 5 times.
✓ Branch 26 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 59 times.
✓ Branch 35 taken 10 times.
69 this%graph(s)%edge_weights = graph(s)%edge_weights
630
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
10 this%graph(s)%num_edges = graph(s)%num_edges
631
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
10 this%graph(s)%num_vertices = graph(s)%num_vertices
632
14/26
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 10 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 10 times.
✓ Branch 30 taken 118 times.
✓ Branch 31 taken 10 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 118 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 10 times.
138 if(any(this%graph(s)%adj_ja(1,:).gt.this%graph(s)%num_vertices))then
633 write(*,*) "Error: graph adjacency matrix has indices greater than &
634 &the number of vertices", s, &
635 this%graph(s)%num_vertices
636 write(*,*) "Adjacency matrix indices: ", this%graph(s)%adj_ja
637 stop
638 end if
639 end do
640
641 10 end subroutine set_graph_duvenaud
642 !##############################################################################!
643
644
645 !##############################################################################!
646 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
647 !##############################################################################!
648
649
650 !###############################################################################
651 1 subroutine print_to_unit_duvenaud(this, unit)
652 !! Print kipf message passing layer to unit
653 use coreutils, only: to_upper
654 implicit none
655
656 ! Arguments
657 class(duvenaud_msgpass_layer_type), intent(in) :: this
658 !! Instance of the message passing layer
659 integer, intent(in) :: unit
660 !! Filename
661
662 ! Local variables
663 integer :: t
664 !! Loop index
665 character(100) :: fmt
666 !! Format string
667
668
669 ! Write initial parameters
670 !---------------------------------------------------------------------------
671 1 write(unit,'(3X,"NUM_TIME_STEPS = ",I0)') this%num_time_steps
672 write(fmt,'("(3X,""NUM_VERTEX_FEATURES ="",",I0,"(1X,I0))")') &
673 1 this%num_time_steps + 1
674 1 write(unit,fmt) this%num_vertex_features
675 write(fmt,'("(3X,""NUM_EDGE_FEATURES ="",",I0,"(1X,I0))")') &
676 1 this%num_time_steps + 1
677 1 write(unit,fmt) this%num_edge_features
678
679
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%activation%name .ne. 'none')then
680 1 call this%activation%print_to_unit(unit, identifier='MESSAGE')
681 end if
682
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%activation_readout%name .ne. 'none')then
683 1 call this%activation_readout%print_to_unit(unit, identifier='READOUT')
684 end if
685
686
687 ! Write learned parameters
688 !---------------------------------------------------------------------------
689 1 write(unit,'("WEIGHTS")')
690
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
4 do t = 1, this%num_time_steps, 1
691
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 3 times.
✓ Branch 25 taken 1980 times.
✓ Branch 26 taken 3 times.
1984 write(unit,'(5(E16.8E2))') this%params(t)%val(:,1)
692 end do
693
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
4 do t = 1, this%num_time_steps, 1
694
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 3 times.
✓ Branch 25 taken 150 times.
✓ Branch 26 taken 3 times.
154 write(unit,'(5(E16.8E2))') this%params(t+this%num_time_steps)%val(:,1)
695 end do
696 1 write(unit,'("END WEIGHTS")')
697
698 1 end subroutine print_to_unit_duvenaud
699 !###############################################################################
700
701
702 !###############################################################################
703 1 subroutine read_duvenaud(this, unit, verbose)
704 !! Read the message passing layer
705 implicit none
706
707 ! Arguments
708 class(duvenaud_msgpass_layer_type), intent(inout) :: this
709 !! Instance of the message passing layer
710 integer, intent(in) :: unit
711 !! Unit to read from
712 integer, optional, intent(in) :: verbose
713 !! Verbosity level
714 1 end subroutine read_duvenaud
715 !###############################################################################
716
717
718 !###############################################################################
719 1 function read_duvenaud_msgpass_layer(unit, verbose) result(layer)
720 !! Read duvenaud message passing layer from file and return layer
721 implicit none
722
723 ! Arguments
724 integer, intent(in) :: unit
725 !! Unit number
726 integer, optional, intent(in) :: verbose
727 !! Verbosity level
728 class(base_layer_type), allocatable :: layer
729 !! Instance of the message passing layer
730
731 ! Local variables
732 integer :: verbose_ = 0
733 !! Verbosity level
734
735
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(present(verbose)) verbose_ = verbose
736 allocate(layer, source = duvenaud_msgpass_layer_type( &
737 num_vertex_features = [ 0 ], &
738 num_edge_features = [ 0 ], &
739 num_time_steps = 1, &
740 max_vertex_degree = 1, &
741 num_outputs = 1 &
742
30/98
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✓ Branch 59 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 62 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✓ Branch 69 taken 1 times.
✗ Branch 70 not taken.
✗ Branch 71 not taken.
✓ Branch 72 taken 1 times.
✗ Branch 73 not taken.
✗ Branch 74 not taken.
✗ Branch 76 not taken.
✓ Branch 77 taken 1 times.
✓ Branch 78 taken 1 times.
✗ Branch 79 not taken.
✗ Branch 80 not taken.
✓ Branch 81 taken 1 times.
✓ Branch 83 taken 1 times.
✗ Branch 84 not taken.
✓ Branch 85 taken 1 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 1 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 1 times.
✗ Branch 90 not taken.
✓ Branch 91 taken 1 times.
✗ Branch 92 not taken.
✗ Branch 93 not taken.
✓ Branch 94 taken 1 times.
✓ Branch 96 taken 1 times.
✗ Branch 97 not taken.
✗ Branch 98 not taken.
✓ Branch 99 taken 1 times.
✗ Branch 101 not taken.
✓ Branch 102 taken 1 times.
✗ Branch 103 not taken.
✓ Branch 104 taken 1 times.
✗ Branch 106 not taken.
✓ Branch 107 taken 1 times.
2 ))
743 1 call layer%read(unit, verbose=verbose_)
744
745 2 end function read_duvenaud_msgpass_layer
746 !###############################################################################
747
748
749 !##############################################################################!
750 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
751 !##############################################################################!
752
753
754 !##############################################################################!
755
1/2
✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
9 subroutine update_message_duvenaud(this, input)
756 !! Update the message
757 implicit none
758
759 ! Arguments
760 class(duvenaud_msgpass_layer_type), intent(inout), target :: this
761 !! Instance of the message passing layer
762 class(array_type), dimension(:,:), intent(in), target :: input
763 !! Input to the message passing layer
764
765 ! Local variables
766 integer :: s, t
767 !! Batch index, time step
768 logical :: has_activation
769 type(array_type), pointer :: ptr1, ptr2, ptr3, ptr_edge, ptr_params
770 !! Pointers to arrays
771
772
773
2/2
✓ Branch 0 taken 5 times.
✓ Branch 1 taken 4 times.
9 if(allocated(this%z))then
774
7/14
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
5 if(size(this%z,2).ne.size(input,2))then
775 deallocate(this%z)
776 allocate(this%z(this%num_time_steps,size(input,2)))
777 end if
778 else
779
34/64
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 4 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
✓ Branch 22 taken 4 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 4 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✓ Branch 63 taken 4 times.
✓ Branch 64 taken 4 times.
✓ Branch 65 taken 7 times.
✓ Branch 66 taken 4 times.
✓ Branch 67 taken 7 times.
✗ Branch 68 not taken.
✗ Branch 69 not taken.
✓ Branch 70 taken 7 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 7 times.
✗ Branch 73 not taken.
✓ Branch 74 taken 7 times.
✗ Branch 75 not taken.
✓ Branch 76 taken 7 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 7 times.
✗ Branch 79 not taken.
✓ Branch 80 taken 7 times.
15 allocate(this%z(this%num_time_steps,size(input,2)))
780 end if
781
782
783
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
9 if(.not.allocated(this%activation))then
784 has_activation = .false.
785 else
786
2/4
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
9 if(trim(this%activation%name).eq."none")then
787 has_activation = .true.
788 else
789 9 has_activation = .true.
790 end if
791 end if
792
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 9 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 9 times.
✓ Branch 18 taken 9 times.
✓ Branch 19 taken 9 times.
18 do s = 1, size(input,2)
793
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
9 ptr1 => input(1,s)
794
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
9 ptr_edge => input(2,s)
795
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 9 times.
30 do t = 1, this%num_time_steps
796 ptr2 => duvenaud_propagate( &
797 ptr1, ptr_edge, &
798 48 this%graph(s)%adj_ia, this%graph(s)%adj_ja &
799
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
12 )
800
801
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
12 ptr_params => this%params(t)
802 ptr3 => duvenaud_update( &
803 ptr2, ptr_params, &
804 24 this%graph(s)%adj_ia, &
805 this%min_vertex_degree, this%max_vertex_degree &
806
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
12 )
807
1/2
✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
12 if(has_activation)then
808 12 ptr3 => this%activation%apply( ptr3 )
809 end if
810
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
12 call this%z(t,s)%zero_grad()
811
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
12 call this%z(t,s)%assign_and_deallocate_source(ptr3)
812
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
12 this%z(t,s)%is_temporary = .false.
813
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
21 ptr1 => this%z(t,s)
814 end do
815 end do
816
817 9 end subroutine update_message_duvenaud
818 !###############################################################################
819
820
821 !##############################################################################!
822 9 subroutine update_readout_duvenaud(this)
823 !! Update the readout
824 implicit none
825
826 ! Arguments
827 class(duvenaud_msgpass_layer_type), intent(inout), target :: this
828 !! Instance of the message passing layer
829
830 ! Local variables
831 integer :: s, t, batch_size
832 !! Loop indices
833 type(array_type), pointer :: ptr1, ptr2, ptr3, ptr_params, ptr_z
834
835
836 9 batch_size = size(this%z,2)
837
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
9 call this%output(1,1)%zero_grad()
838
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 9 times.
21 do t = 1, this%num_time_steps, 1
839
2/2
✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
33 do s = 1, batch_size, 1
840
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
12 ptr_params => this%params(t+this%num_time_steps)
841
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
12 ptr_z => this%z(t,s)
842 ptr1 => matmul( &
843 ptr_params, &
844 ptr_z &
845 12 )
846 12 ptr2 => this%activation_readout%apply( ptr1 )
847
2/2
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 3 times.
24 if(t.eq.1.and.s.eq.1)then
848 ptr3 => &
849 9 sum( ptr2, dim = 2, new_dim_index=s, new_dim_size=batch_size )
850 else
851 ptr3 => ptr3 + &
852 3 sum( ptr2, dim = 2, new_dim_index=s, new_dim_size=batch_size )
853 end if
854 end do
855 end do
856
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
9 call this%output(1,1)%assign_and_deallocate_source(ptr3)
857
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
9 this%output(1,1)%is_temporary = .false.
858
859 9 end subroutine update_readout_duvenaud
860 !###############################################################################
861
862
863 !###############################################################################
864 1 subroutine emit_onnx_nodes_duvenaud( &
865 this, prefix, &
866 1 nodes, num_nodes, max_nodes, &
867
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 inits, num_inits, max_inits, &
868 input_name, is_last_layer, format &
869 )
870 !! Emit ONNX JSON nodes for Duvenaud GNN layer
871 !!
872 !! Decomposes the Duvenaud message passing layer into standard ONNX ops:
873 !! Gather, Concat, ScatterElements, MatMul, Sigmoid/activation,
874 !! Softmax, ReduceSum, Add, Div, Clip, Sub, etc.
875 !!
876 !! This override is called by write_onnx instead of the standard
877 !! node emission logic, making the ONNX export extensible for new
878 !! GNN layer types.
879 use athena__onnx_msgpass_utils, only: emit_output_identity
880 implicit none
881
882 ! Arguments
883 class(duvenaud_msgpass_layer_type), intent(in) :: this
884 !! Instance of the layer
885 character(*), intent(in) :: prefix
886 !! Node name prefix (e.g. "node_2")
887 type(onnx_node_type), intent(inout), dimension(:) :: nodes
888 !! Accumulator for ONNX nodes
889 integer, intent(inout) :: num_nodes
890 !! Current number of nodes
891 integer, intent(in) :: max_nodes
892 !! Maximum capacity
893 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
894 !! Accumulator for ONNX initialisers
895 integer, intent(inout) :: num_inits
896 !! Current number of initialisers
897 integer, intent(in) :: max_inits
898 !! Maximum capacity
899 character(*), optional, intent(in) :: input_name
900 !! Unused sequential input name
901 logical, optional, intent(in) :: is_last_layer
902 !! Unused last-layer flag
903 integer, optional, intent(in) :: format
904 !! Unused export format selector
905
906 ! Local variables
907 integer :: t
908 character(128) :: cur_vertex_name, readout_accum
909
910 ! Must be called with vertex_input, edge_input etc. already set
911 ! These are stored in the node's input naming convention
912 ! prefix is e.g. "node_2", inputs come from the calling context
913
914 ! ===== Emit message passing time steps =====
915
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
3 do t = 1, this%num_time_steps
916 call emit_duvenaud_timestep( &
917 prefix, t, &
918 8 this%num_vertex_features(t-1), this%num_edge_features(0), &
919 4 this%num_vertex_features(t), &
920 this%min_vertex_degree, this%max_vertex_degree, &
921 20 this%params(t)%val(:,1), &
922 this%activation%name, &
923 nodes, num_nodes, max_nodes, &
924 inits, num_inits, max_inits, &
925 cur_vertex_name &
926
22/44
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 2 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 2 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 2 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 2 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 2 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 2 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 2 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 2 times.
3 )
927 end do
928
929 ! ===== Emit readout =====
930 call emit_duvenaud_readout_impl( &
931 prefix, this, &
932 nodes, num_nodes, max_nodes, &
933 inits, num_inits, max_inits, &
934 readout_accum &
935
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
1 )
936
937 ! The readout output becomes the layer output for downstream layers.
938 call emit_output_identity( &
939 prefix, trim(readout_accum), this%activation%name, &
940
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
1 nodes, num_nodes)
941
942
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 end subroutine emit_onnx_nodes_duvenaud
943 !###############################################################################
944
945
946 !###############################################################################
947 subroutine emit_onnx_graph_inputs_duvenaud( &
948 this, prefix, &
949 graph_inputs, num_inputs &
950 )
951 !! Emit graph input tensor declarations for Duvenaud GNN layer
952 !!
953 !! Adds: vertex features, edge features, edge_index [3, ncsr], degree
954 use athena__onnx_msgpass_utils, only: emit_msgpass_graph_inputs
955 implicit none
956
957 ! Arguments
958 class(duvenaud_msgpass_layer_type), intent(in) :: this
959 !! Instance of the layer
960 character(*), intent(in) :: prefix
961 !! Input name prefix (e.g. "input_1")
962 type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
963 !! Accumulator for graph input tensor declarations
964 integer, intent(inout) :: num_inputs
965 !! Current number of graph input declarations
966
967 call emit_msgpass_graph_inputs( &
968 prefix, this%input_shape, graph_inputs, num_inputs)
969
970 end subroutine emit_onnx_graph_inputs_duvenaud
971 !###############################################################################
972
973
974 !###############################################################################
975 2 subroutine emit_duvenaud_timestep( &
976 prefix, t, nv_in, ne_in, nv_out, &
977 2 min_degree, max_degree, weight_data, activation_name, &
978 2 nodes, num_nodes, max_nodes, &
979
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
4 inits, num_inits, max_inits, vertex_out)
980 !! Emit ONNX nodes for one Duvenaud message passing time step.
981 use athena__onnx_utils, only: emit_node, emit_constant_int64, &
982 emit_activation_node
983 use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
984 emit_edge_index_component, emit_scatter_aggregator
985 implicit none
986
987 ! Arguments
988 character(*), intent(in) :: prefix
989 integer, intent(in) :: t
990 integer, intent(in) :: nv_in, ne_in, nv_out
991 integer, intent(in) :: min_degree, max_degree
992 real(real32), intent(in) :: weight_data(:)
993 character(*), intent(in) :: activation_name
994 type(onnx_node_type), intent(inout), dimension(:) :: nodes
995 integer, intent(inout) :: num_nodes
996 integer, intent(in) :: max_nodes
997 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
998 integer, intent(inout) :: num_inits
999 integer, intent(in) :: max_inits
1000 character(128), intent(out) :: vertex_out
1001
1002 ! Local variables
1003 character(128) :: tp, tmp1, tmp2, tmp3
1004 character(128) :: vertex_in, edge_in, edge_index_in, degree_in
1005 character(128) :: src_idx, edge_idx, target_idx
1006 character(128) :: msg_name, aggr_name, sq_out
1007 character(len=*), parameter :: onnx_axis0_attr = &
1008 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
1009 character(len=*), parameter :: onnx_concat_axis1_attr = &
1010 ' "attribute": [{"name": "axis", "i": "1", "type": "INT"}]'
1011
1012
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tp, '(A,"_t",I0)') trim(prefix), t
1013
1014 ! Input tensor names follow the convention set during write_onnx.
1015 ! For t=1 the vertex input comes from the previous layer, while edge,
1016 ! edge_index, and degree are always rooted at the original graph input.
1017
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(vertex_in, '(A,"_vertex_in")') trim(prefix)
1018
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(edge_in, '(A,"_edge_in")') trim(prefix)
1019
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
1020
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(degree_in, '(A,"_degree_in")') trim(prefix)
1021
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 if(t .gt. 1)then
1022 call get_timestep_output_name( &
1023 1 prefix, t-1, activation_name, '_sq_out', '_sq', vertex_in)
1024 end if
1025
1026 ! --- Step 1: Extract source and edge-feature indices from edge_index ---
1027
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tmp1, '(A,"_idx0")') trim(tp)
1028 call emit_constant_int64(trim(tmp1), [0], [1], &
1029
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1030
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tmp2, '(A,"_idx1")') trim(tp)
1031 call emit_constant_int64(trim(tmp2), [1], [1], &
1032
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1033
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tmp3, '(A,"_idx2")') trim(tp)
1034 call emit_constant_int64(trim(tmp3), [2], [1], &
1035
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1036
1037 call emit_edge_index_component( &
1038
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
2 tp, edge_index_in, trim(tmp1), 'src', src_idx, nodes, num_nodes)
1039 call emit_edge_index_component( &
1040
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
2 tp, edge_index_in, trim(tmp2), 'eidx', edge_idx, nodes, num_nodes)
1041 call emit_edge_index_component( &
1042
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
2 tp, edge_index_in, trim(tmp3), 'tgt', target_idx, nodes, num_nodes)
1043
1044 ! --- Step 2: Gather source vertex features and edge features ---
1045
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tmp1, '(A,"_src_feat")') trim(tp)
1046 call emit_node('Gather', trim(tp)//'_gather_vfeat', &
1047 trim(tmp1), onnx_axis0_attr, nodes, num_nodes, &
1048
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(vertex_in), in2=trim(src_idx))
1049
1050
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tmp2, '(A,"_edge_feat")') trim(tp)
1051 call emit_node('Gather', trim(tp)//'_gather_efeat', &
1052 trim(tmp2), onnx_axis0_attr, nodes, num_nodes, &
1053
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(edge_in), in2=trim(edge_idx))
1054
1055 ! --- Step 3: Concat source vertex + edge features ---
1056
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(msg_name, '(A,"_msg")') trim(tp)
1057 call emit_node('Concat', trim(tp)//'_concat_msg', &
1058 trim(msg_name), onnx_concat_axis1_attr, nodes, num_nodes, &
1059
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(tmp1), in2=trim(tmp2))
1060
1061 ! --- Step 4: Scatter-add to aggregate messages per target vertex ---
1062 call emit_scatter_aggregator( &
1063 tp, vertex_in, target_idx, msg_name, nv_in + ne_in, &
1064
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
2 nodes, num_nodes, inits, num_inits, aggr_name)
1065
1066 ! --- Step 5: Degree-specific weight application ---
1067 call emit_duvenaud_degree_update( &
1068 tp, degree_in, min_degree, max_degree, nv_in + ne_in, nv_out, &
1069
9/18
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
2 weight_data, aggr_name, nodes, num_nodes, inits, num_inits, sq_out)
1070
1071 ! --- Step 6: Activation ---
1072
2/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 if(trim(activation_name) .ne. 'none')then
1073 call emit_activation_node(activation_name, trim(tp)//'_sq', &
1074
5/10
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 trim(sq_out), nodes, num_nodes, max_nodes)
1075
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 vertex_out = trim(nodes(num_nodes)%outputs(1))
1076 else
1077 vertex_out = trim(sq_out)
1078 end if
1079
1080
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 end subroutine emit_duvenaud_timestep
1081 !###############################################################################
1082
1083
1084 !###############################################################################
1085 2 subroutine emit_duvenaud_degree_update( &
1086 tp, degree_in, min_degree, max_degree, feature_dim, nv_out, &
1087
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
2 weight_data, aggr_in, nodes, num_nodes, inits, num_inits, sq_out)
1088 !! Emit the degree-dependent weight selection and update block.
1089 use athena__onnx_utils, only: emit_node, emit_squeeze_node, &
1090 emit_constant_int64, emit_constant_float
1091 use athena__onnx_msgpass_utils, only: emit_weight_initialiser_3d
1092 implicit none
1093
1094 ! Arguments
1095 character(*), intent(in) :: tp, degree_in, aggr_in
1096 integer, intent(in) :: min_degree, max_degree, feature_dim, nv_out
1097 real(real32), intent(in) :: weight_data(:)
1098 type(onnx_node_type), intent(inout), dimension(:) :: nodes
1099 integer, intent(inout) :: num_nodes
1100 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
1101 integer, intent(inout) :: num_inits
1102 character(128), intent(out) :: sq_out
1103
1104 ! Local variables
1105 character(128) :: min_deg_name, max_deg_name, deg_float
1106 character(128) :: deg_clip, deg_idx_float, deg_idx
1107 character(128) :: weight_name, weight_sel, deg_us
1108 character(128) :: aggr_norm, aggr_us, matmul_out
1109 character(128) :: axes1_name, axes2_name
1110 character(len=*), parameter :: onnx_axis0_attr = &
1111 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
1112 character(len=*), parameter :: onnx_cast_float_attr = &
1113 ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
1114 character(len=*), parameter :: onnx_cast_int64_attr = &
1115 ' "attribute": [{"name": "to", "i": "7", "type": "INT"}]'
1116
1117 ! Clip degree to the supported bucket interval.
1118
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(min_deg_name, '(A,"_min_deg")') trim(tp)
1119 call emit_constant_float(trim(min_deg_name), &
1120 [ real(min_degree, real32) ], [1], &
1121
9/16
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
4 nodes, num_nodes, inits, num_inits)
1122
1123
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(max_deg_name, '(A,"_max_deg")') trim(tp)
1124 call emit_constant_float(trim(max_deg_name), &
1125 [ real(max_degree, real32) ], [1], &
1126
9/16
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
4 nodes, num_nodes, inits, num_inits)
1127
1128
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(deg_float, '(A,"_deg_f")') trim(tp)
1129 call emit_node('Cast', trim(tp)//'_cast_deg', &
1130 trim(deg_float), onnx_cast_float_attr, nodes, num_nodes, &
1131
6/12
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 in1=trim(degree_in))
1132
1133
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(deg_clip, '(A,"_deg_clip")') trim(tp)
1134 call emit_node('Clip', trim(tp)//'_clip_deg', &
1135 trim(deg_clip), '', nodes, num_nodes, &
1136
8/16
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 22 not taken.
2 in1=trim(deg_float), in2=trim(min_deg_name), in3=trim(max_deg_name))
1137
1138 ! Shift clipped degrees so they can index the weight bank from zero.
1139
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(deg_idx_float, '(A,"_deg_idx_f")') trim(tp)
1140 call emit_node('Sub', trim(tp)//'_sub_mindeg', &
1141 trim(deg_idx_float), '', nodes, num_nodes, &
1142
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(deg_clip), in2=trim(min_deg_name))
1143
1144
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(deg_idx, '(A,"_deg_idx")') trim(tp)
1145 call emit_node('Cast', trim(tp)//'_cast_degidx', &
1146 trim(deg_idx), onnx_cast_int64_attr, nodes, num_nodes, &
1147
6/12
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 in1=trim(deg_idx_float))
1148
1149 ! Store the degree-specific weight bank as a 3D initialiser.
1150
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(weight_name, '(A,"_W")') trim(tp)
1151 call emit_weight_initialiser_3d( &
1152 trim(weight_name), max_degree - min_degree + 1, &
1153
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nv_out, feature_dim, weight_data, inits, num_inits)
1154
1155
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(weight_sel, '(A,"_W_sel")') trim(tp)
1156 call emit_node('Gather', trim(tp)//'_gather_W', &
1157 trim(weight_sel), onnx_axis0_attr, nodes, num_nodes, &
1158
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(weight_name), in2=trim(deg_idx))
1159
1160 ! Divide by degree and reshape for batched MatMul.
1161
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(axes1_name, '(A,"_us_ax1_deg")') trim(tp)
1162 call emit_constant_int64(trim(axes1_name), [1], [1], &
1163
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1164
1165
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(deg_us, '(A,"_deg_us")') trim(tp)
1166 call emit_node('Unsqueeze', trim(tp)//'_us_deg', &
1167 trim(deg_us), '', nodes, num_nodes, &
1168
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(deg_clip), in2=trim(axes1_name))
1169
1170
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(aggr_norm, '(A,"_aggr_norm")') trim(tp)
1171 call emit_node('Div', trim(tp)//'_div_deg', &
1172 trim(aggr_norm), '', nodes, num_nodes, &
1173
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(aggr_in), in2=trim(deg_us))
1174
1175
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(axes2_name, '(A,"_us_ax2")') trim(tp)
1176 call emit_constant_int64(trim(axes2_name), [2], [1], &
1177
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1178
1179
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(aggr_us, '(A,"_aggr_us")') trim(tp)
1180 call emit_node('Unsqueeze', trim(tp)//'_us_aggr', &
1181 trim(aggr_us), '', nodes, num_nodes, &
1182
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(aggr_norm), in2=trim(axes2_name))
1183
1184
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(matmul_out, '(A,"_matmul_out")') trim(tp)
1185 call emit_node('MatMul', trim(tp)//'_matmul', &
1186 trim(matmul_out), '', nodes, num_nodes, &
1187
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(weight_sel), in2=trim(aggr_us))
1188
1189
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(sq_out, '(A,"_sq_out")') trim(tp)
1190 call emit_squeeze_node(trim(tp)//'_sq_mm', &
1191 trim(matmul_out), trim(axes2_name), trim(sq_out), &
1192
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✓ Branch 17 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 22 not taken.
2 nodes, num_nodes)
1193
1194
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 end subroutine emit_duvenaud_degree_update
1195 !###############################################################################
1196
1197
1198 !###############################################################################
1199 2 subroutine emit_duvenaud_readout_step( &
1200 2 prefix, activation_name, t, nv, no, weight_data, &
1201
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
2 nodes, num_nodes, inits, num_inits, step_sum)
1202 !! Emit one Duvenaud readout timestep.
1203 !!
1204 !! This expands to the timestep readout projection, the readout softmax,
1205 !! and the reduction over nodes before the timestep contributions are added.
1206 use athena__onnx_utils, only: emit_node, emit_constant_int64
1207 use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
1208 emit_weight_initialiser_2d
1209 implicit none
1210
1211 ! Arguments
1212 character(*), intent(in) :: prefix
1213 character(*), intent(in) :: activation_name
1214 integer, intent(in) :: t, nv, no
1215 real(real32), intent(in) :: weight_data(:)
1216 type(onnx_node_type), intent(inout), dimension(:) :: nodes
1217 integer, intent(inout) :: num_nodes
1218 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
1219 integer, intent(inout) :: num_inits
1220 character(128), intent(out) :: step_sum
1221
1222 ! Local variables
1223 character(128) :: tp, z_name, weight_name, z_transpose
1224 character(128) :: matmul_out, softmax_out, axis1_name
1225 character(len=*), parameter :: onnx_softmax_axis0_attr = &
1226 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
1227 character(len=*), parameter :: onnx_transpose_10_attr = &
1228 ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
1229 '"type": "INTS"}]'
1230 character(len=*), parameter :: onnx_reduce_sum_attr = &
1231 ' "attribute": [{"name": "keepdims", "i": "0", ' // &
1232 '"type": "INT"}]'
1233
1234
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(tp, '(A,"_ro_t",I0)') trim(prefix), t
1235 call get_timestep_output_name( &
1236 2 prefix, t, activation_name, '_sq_out', '_sq', z_name)
1237
1238 ! Store the readout matrix for timestep t as an ONNX initialiser.
1239
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(weight_name, '(A,"_R")') trim(tp)
1240 call emit_weight_initialiser_2d( &
1241
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 trim(weight_name), no, nv, weight_data, inits, num_inits)
1242
1243 ! Transpose node features before multiplying by the readout matrix.
1244
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(z_transpose, '(A,"_zt")') trim(tp)
1245 call emit_node('Transpose', trim(tp)//'_transpose_z', &
1246 trim(z_transpose), onnx_transpose_10_attr, nodes, num_nodes, &
1247
6/12
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 in1=trim(z_name))
1248
1249
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(matmul_out, '(A,"_Rz")') trim(tp)
1250 call emit_node('MatMul', trim(tp)//'_matmul_R', &
1251 trim(matmul_out), '', nodes, num_nodes, &
1252
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(weight_name), in2=trim(z_transpose))
1253
1254 ! Softmax and ReduceSum reproduce the ATHENA readout accumulation.
1255
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(softmax_out, '(A,"_sm")') trim(tp)
1256 call emit_node('Softmax', trim(tp)//'_softmax', &
1257 trim(softmax_out), onnx_softmax_axis0_attr, nodes, num_nodes, &
1258
6/12
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
2 in1=trim(matmul_out))
1259
1260
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(axis1_name, '(A,"_ax1")') trim(tp)
1261 call emit_constant_int64(trim(axis1_name), [1], [1], &
1262
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
2 nodes, num_nodes, inits, num_inits)
1263
1264
1/2
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
2 write(step_sum, '(A,"_sum")') trim(tp)
1265 call emit_node('ReduceSum', trim(tp)//'_reducesum', &
1266 trim(step_sum), onnx_reduce_sum_attr, nodes, num_nodes, &
1267
7/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
2 in1=trim(softmax_out), in2=trim(axis1_name))
1268
1269
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 end subroutine emit_duvenaud_readout_step
1270 !###############################################################################
1271
1272
1273 !###############################################################################
1274 1 subroutine emit_duvenaud_readout_impl( &
1275 prefix, layer, &
1276 1 nodes, num_nodes, max_nodes, &
1277
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 inits, num_inits, max_inits, &
1278
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 readout_output &
1279 )
1280 !! Emit ONNX nodes for Duvenaud readout
1281 use athena__onnx_utils, only: emit_node, emit_constant_int64
1282 implicit none
1283 character(*), intent(in) :: prefix
1284 class(duvenaud_msgpass_layer_type), intent(in) :: layer
1285 type(onnx_node_type), intent(inout), dimension(:) :: nodes
1286 integer, intent(inout) :: num_nodes
1287 integer, intent(in) :: max_nodes
1288 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
1289 integer, intent(inout) :: num_inits
1290 integer, intent(in) :: max_inits
1291 character(128), intent(out) :: readout_output
1292
1293 ! Local variables
1294 integer :: t
1295 character(128) :: tmp1, prev_accum, step_sum
1296
1297
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
3 do t = 1, layer%num_time_steps
1298 call emit_duvenaud_readout_step( &
1299 prefix, layer%activation%name, t, &
1300 4 layer%num_vertex_features(t), layer%num_outputs, &
1301 20 layer%params(t + layer%num_time_steps)%val(:,1), &
1302
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 2 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 2 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 2 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 2 times.
2 nodes, num_nodes, inits, num_inits, step_sum)
1303
1304 ! Accumulate across timesteps
1305
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
3 if(t .eq. 1)then
1306
2/4
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 prev_accum = trim(step_sum)
1307 else
1308
1/2
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 write(tmp1, '(A,"_ro_t",I0,"_accum")') trim(prefix), t
1309 call emit_node('Add', trim(tmp1)//'_node', &
1310 trim(tmp1), '', nodes, num_nodes, &
1311
7/14
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 14 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
1 in1=trim(prev_accum), in2=trim(step_sum))
1312
2/4
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 prev_accum = trim(tmp1)
1313 end if
1314 end do
1315
1316 ! Unsqueeze to add batch dimension: [no] → [1, no]
1317
1/2
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 write(tmp1, '(A,"_ro_ax0")') trim(prefix)
1318 call emit_constant_int64(trim(tmp1), [0], [1], &
1319
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
1 nodes, num_nodes, inits, num_inits)
1320
1/2
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 write(readout_output, '(A,"_readout")') trim(prefix)
1321 call emit_node('Unsqueeze', trim(prefix)//'_us_readout', &
1322 trim(readout_output), '', nodes, num_nodes, &
1323
7/14
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 14 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
1 in1=trim(prev_accum), in2=trim(tmp1))
1324
1325
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 end subroutine emit_duvenaud_readout_impl
1326 !###############################################################################
1327
1328
94/185
✓ Branch 0 taken 17 times.
✓ Branch 1 taken 18 times.
✓ Branch 2 taken 17 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✓ Branch 5 taken 17 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 17 times.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 39 taken 17 times.
✓ Branch 40 taken 18 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 17 times.
✓ Branch 43 taken 18 times.
✓ Branch 44 taken 17 times.
✓ Branch 45 taken 35 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 17 times.
✓ Branch 48 taken 18 times.
✓ Branch 49 taken 35 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 35 times.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✓ Branch 54 taken 17 times.
✓ Branch 55 taken 17 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 62 times.
✓ Branch 58 taken 17 times.
✓ Branch 59 taken 62 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 62 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 31 times.
✓ Branch 64 taken 31 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 62 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 62 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 80 times.
✓ Branch 71 taken 17 times.
✓ Branch 72 taken 2 times.
✓ Branch 73 taken 16 times.
✓ Branch 74 taken 18 times.
✓ Branch 75 taken 17 times.
✓ Branch 76 taken 1 times.
✓ Branch 77 taken 34 times.
✓ Branch 78 taken 1 times.
✓ Branch 79 taken 1 times.
✓ Branch 80 taken 17 times.
✓ Branch 81 taken 1 times.
✓ Branch 82 taken 17 times.
✓ Branch 83 taken 1 times.
✓ Branch 84 taken 17 times.
✓ Branch 85 taken 1 times.
✓ Branch 86 taken 18 times.
✗ Branch 87 not taken.
✓ Branch 88 taken 1 times.
✗ Branch 89 not taken.
✓ Branch 90 taken 18 times.
✗ Branch 91 not taken.
✗ Branch 92 not taken.
✓ Branch 93 taken 1 times.
✗ Branch 94 not taken.
✓ Branch 95 taken 1 times.
✗ Branch 96 not taken.
✓ Branch 97 taken 1 times.
✗ Branch 98 not taken.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✗ Branch 101 not taken.
✗ Branch 102 not taken.
✗ Branch 103 not taken.
✗ Branch 104 not taken.
✓ Branch 105 taken 1 times.
✓ Branch 106 taken 17 times.
✗ Branch 107 not taken.
✗ Branch 108 not taken.
✗ Branch 109 not taken.
✗ Branch 110 not taken.
✗ Branch 111 not taken.
✓ Branch 112 taken 17 times.
✓ Branch 113 taken 1 times.
✓ Branch 114 taken 18 times.
✗ Branch 115 not taken.
✓ Branch 116 taken 18 times.
✗ Branch 117 not taken.
✗ Branch 118 not taken.
✗ Branch 119 not taken.
✗ Branch 120 not taken.
✓ Branch 121 taken 18 times.
✗ Branch 122 not taken.
✓ Branch 123 taken 18 times.
✗ Branch 124 not taken.
✓ Branch 125 taken 18 times.
✗ Branch 126 not taken.
✓ Branch 127 taken 18 times.
✗ Branch 128 not taken.
✓ Branch 129 taken 18 times.
✗ Branch 130 not taken.
✓ Branch 131 taken 2 times.
✓ Branch 132 taken 16 times.
✓ Branch 133 taken 18 times.
✗ Branch 134 not taken.
✓ Branch 135 taken 18 times.
✗ Branch 136 not taken.
✓ Branch 137 taken 18 times.
✗ Branch 138 not taken.
✓ Branch 139 taken 2 times.
✓ Branch 140 taken 16 times.
✓ Branch 142 taken 18 times.
✗ Branch 143 not taken.
✓ Branch 144 taken 2 times.
✓ Branch 145 taken 16 times.
✗ Branch 146 not taken.
✓ Branch 147 taken 16 times.
✓ Branch 149 taken 18 times.
✗ Branch 150 not taken.
✓ Branch 151 taken 18 times.
✗ Branch 152 not taken.
✗ Branch 153 not taken.
✗ Branch 154 not taken.
✓ Branch 156 taken 18 times.
✗ Branch 157 not taken.
✓ Branch 158 taken 2 times.
✓ Branch 159 taken 16 times.
✗ Branch 160 not taken.
✓ Branch 161 taken 16 times.
✓ Branch 163 taken 18 times.
✗ Branch 164 not taken.
✓ Branch 165 taken 2 times.
✓ Branch 166 taken 16 times.
✓ Branch 167 taken 18 times.
✗ Branch 168 not taken.
✓ Branch 169 taken 2 times.
✓ Branch 170 taken 16 times.
✓ Branch 171 taken 18 times.
✗ Branch 172 not taken.
✓ Branch 173 taken 2 times.
✓ Branch 174 taken 16 times.
✓ Branch 175 taken 18 times.
✗ Branch 176 not taken.
✓ Branch 177 taken 2 times.
✓ Branch 178 taken 16 times.
✗ Branch 179 not taken.
✓ Branch 180 taken 16 times.
✓ Branch 182 taken 18 times.
✗ Branch 183 not taken.
✓ Branch 184 taken 17 times.
✓ Branch 185 taken 1 times.
✓ Branch 187 taken 18 times.
✗ Branch 188 not taken.
✓ Branch 189 taken 18 times.
✗ Branch 190 not taken.
334 end module athena__duvenaud_msgpass_layer
1329