GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_kipf_msgpass_layer.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 241 291 82.8%
Functions: 0 0 -%
Branches: 748 1599 46.8%

Line Branch Exec Source
1 module athena__kipf_msgpass_layer
2 !! Module implementing Kipf & Welling Graph Convolutional Network (GCN)
3 !!
4 !! This module implements the graph convolutional layer from Kipf & Welling
5 !! (2017) with symmetric degree normalisation for semi-supervised learning.
6 !!
7 !! Mathematical operation:
8 !! \[ H^{(l+1)} = \sigma\left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) \]
9 !!
10 !! where:
11 !! * \( \tilde{A} = A + I \) (adjacency matrix with added self-loops)
12 !! * \( \tilde{D} \) is the degree matrix of \( \tilde{A} \)
13 !! * \( H^{(l)} \) is the node feature matrix at layer l
14 !! * \( W^{(l)} \) is a learnable weight matrix
15 !! * \( \sigma \) is the activation function
16 !!
17 !! The normalisation \( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} \) ensures
18 !! proper scaling by degree.
19 !! Preserves graph structure, producing node-level (not graph-level) outputs.
20 !!
21 !! Reference: Kipf & Welling (2017), ICLR
22 use coreutils, only: real32, stop_program
23 use graphstruc, only: graph_type
24 use athena__misc_types, only: base_actv_type, base_init_type, &
25 onnx_attribute_type, onnx_node_type, onnx_initialiser_type, &
26 onnx_tensor_type
27 use diffstruc, only: array_type
28 use athena__base_layer, only: base_layer_type
29 use athena__msgpass_layer, only: msgpass_layer_type
30 use athena__diffstruc_extd, only: kipf_propagate, kipf_update
31 use diffstruc, only: matmul
32 implicit none
33
34
35 private
36
37 public :: kipf_msgpass_layer_type
38 public :: read_kipf_msgpass_layer
39
40
41 !-------------------------------------------------------------------------------
42 ! Message passing layer
43 !-------------------------------------------------------------------------------
44 type, extends(msgpass_layer_type) :: kipf_msgpass_layer_type
45
46 ! this is for chen 2021 et al
47 ! type(array2d_type), dimension(:), allocatable :: edge_weight
48 ! !! Weights for the edges
49 ! type(array2d_type), dimension(:), allocatable :: vertex_weight
50 ! !! Weights for the vertices
51
52 contains
53 procedure, pass(this) :: get_num_params => get_num_params_kipf
54 !! Get the number of parameters for the message passing layer
55 procedure, pass(this) :: set_hyperparams => set_hyperparams_kipf
56 !! Set the hyperparameters for the message passing layer
57 procedure, pass(this) :: init => init_kipf
58 !! Initialise the message passing layer
59 procedure, pass(this) :: print_to_unit => print_to_unit_kipf
60 !! Print the message passing layer
61 procedure, pass(this) :: read => read_kipf
62 !! Read the message passing layer
63
64 procedure, pass(this) :: update_message => update_message_kipf
65 !! Update the message
66
67 procedure, pass(this) :: update_readout => update_readout_kipf
68 !! Update the readout
69
70 procedure, pass(this) :: get_attributes => get_attributes_kipf
71 !! Get the attributes of the layer (for ONNX export)
72 procedure, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_kipf
73 !! Emit ONNX JSON nodes for Kipf GCN layer
74 end type kipf_msgpass_layer_type
75
76 ! Interface for setting up the MPNN layer
77 !-----------------------------------------------------------------------------
78 interface kipf_msgpass_layer_type
79 !! Interface for setting up the MPNN layer
80 module function layer_setup( &
81 num_vertex_features, num_time_steps, &
82 activation, &
83 kernel_initialiser, &
84 verbose &
85 ) result(layer)
86 !! Set up the message passing layer
87 integer, dimension(:), intent(in) :: num_vertex_features
88 !! Number of features
89 integer, intent(in) :: num_time_steps
90 !! Number of time steps
91 class(*), optional, intent(in) :: activation, kernel_initialiser
92 !! Activation function and kernel initialiser
93 integer, optional, intent(in) :: verbose
94 !! Verbosity level
95 type(kipf_msgpass_layer_type) :: layer
96 !! Instance of the message passing layer
97 end function layer_setup
98 end interface kipf_msgpass_layer_type
99
100 contains
101
102
103 !##############################################################################!
104 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
105 !##############################################################################!
106
107
108 !###############################################################################
109 18 pure function get_num_params_kipf(this) result(num_params)
110 !! Get the number of parameters for the message passing layer
111 !!
112 !! This function calculates the number of parameters for the message passing
113 !! layer.
114 !! This procedure is based on code from the neural-fortran library
115 implicit none
116
117 ! Arguments
118 class(kipf_msgpass_layer_type), intent(in) :: this
119 !! Instance of the message passing layer
120 integer :: num_params
121 !! Number of parameters
122
123 ! Local variables
124 integer :: t
125 !! Loop index
126
127 18 num_params = 0
128
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
36 do t = 1, this%num_time_steps
129 num_params = num_params + &
130
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
36 this%num_vertex_features(t-1) * this%num_vertex_features(t)
131 end do
132
133 18 end function get_num_params_kipf
134 !###############################################################################
135
136
137 !##############################################################################!
138 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
139 !##############################################################################!
140
141
142 !###############################################################################
143 17 module function layer_setup( &
144 17 num_vertex_features, num_time_steps, &
145 activation, &
146 kernel_initialiser, &
147 verbose &
148 17 ) result(layer)
149 !! Set up the message passing layer
150 use athena__activation, only: activation_setup
151 use athena__initialiser, only: initialiser_setup
152 implicit none
153
154 ! Arguments
155 integer, dimension(:), intent(in) :: num_vertex_features
156 !! Number of features
157 integer, intent(in) :: num_time_steps
158 !! Number of time steps
159 class(*), optional, intent(in) :: activation, kernel_initialiser
160 !! Activation function and kernel initialiser
161 integer, optional, intent(in) :: verbose
162 !! Verbosity level
163 type(kipf_msgpass_layer_type) :: layer
164 !! Instance of the message passing layer
165
166 ! Local variables
167 integer :: verbose_ = 0
168 !! Verbosity level
169 51 class(base_actv_type), allocatable :: activation_
170 !! Activation function object
171 37 class(base_init_type), allocatable :: kernel_initialiser_
172 !! Kernel initialisers
173
174
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
17 if(present(verbose)) verbose_ = verbose
175
176
177 !---------------------------------------------------------------------------
178 ! Set activation functions based on input name
179 !---------------------------------------------------------------------------
180
3/4
✓ Branch 0 taken 9 times.
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
17 if(present(activation))then
181
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 9 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 9 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 9 times.
✓ Branch 17 taken 9 times.
✗ Branch 18 not taken.
9 activation_ = activation_setup(activation)
182 else
183
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 8 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 8 times.
✓ Branch 17 taken 8 times.
✗ Branch 18 not taken.
8 activation_ = activation_setup("none")
184 end if
185
186
187 !---------------------------------------------------------------------------
188 ! Define weights (kernels) and biases initialisers
189 !---------------------------------------------------------------------------
190
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
17 if(present(kernel_initialiser))then
191
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
3 kernel_initialiser_ = initialiser_setup(kernel_initialiser)
192 end if
193
194
195 !---------------------------------------------------------------------------
196 ! Set hyperparameters
197 !---------------------------------------------------------------------------
198 call layer%set_hyperparams( &
199 num_vertex_features = num_vertex_features, &
200 num_time_steps = num_time_steps, &
201 activation = activation_, &
202 kernel_initialiser = kernel_initialiser_, &
203 verbose = verbose_ &
204
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 17 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 17 times.
17 )
205
206
207 !---------------------------------------------------------------------------
208 ! Initialise layer shape
209 !---------------------------------------------------------------------------
210
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 17 times.
✓ Branch 6 taken 34 times.
✓ Branch 7 taken 17 times.
51 call layer%init(input_shape=[layer%num_vertex_features(0), 0])
211
212
6/10
✓ Branch 0 taken 17 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✓ Branch 3 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✓ Branch 7 taken 17 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 17 times.
34 end function layer_setup
213 !###############################################################################
214
215
216 !###############################################################################
217 18 subroutine set_hyperparams_kipf( &
218 this, &
219
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 num_vertex_features, &
220 num_time_steps, &
221 activation, &
222 kernel_initialiser, &
223 verbose &
224 )
225 !! Set the hyperparameters for the message passing layer
226 use athena__activation, only: activation_setup
227 use athena__initialiser, only: get_default_initialiser, initialiser_setup
228 implicit none
229
230 ! Arguments
231 class(kipf_msgpass_layer_type), intent(inout) :: this
232 !! Instance of the message passing layer
233 integer, dimension(:), intent(in) :: num_vertex_features
234 !! Number of vertex features
235 integer, intent(in) :: num_time_steps
236 !! Number of time steps
237 class(base_actv_type), allocatable, intent(in) :: activation
238 !! Activation function
239 class(base_init_type), allocatable, intent(in) :: kernel_initialiser
240 !! Kernel initialiser
241 integer, optional, intent(in) :: verbose
242 !! Verbosity level
243
244 ! Local variables
245 integer :: t
246 !! Loop index
247 character(len=256) :: buffer
248
249
250
5/8
✓ Branch 0 taken 17 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 18 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
18 this%name = 'kipf'
251 18 this%type = 'msgp'
252 18 this%input_rank = 2
253 18 this%output_rank = 2
254 18 this%use_graph_output = .true.
255 18 this%num_time_steps = num_time_steps
256
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
18 if(allocated(this%num_vertex_features)) &
257
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(this%num_vertex_features)
258
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
18 if(allocated(this%num_edge_features)) &
259
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(this%num_edge_features)
260
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
18 if(size(num_vertex_features, 1) .eq. 1)then
261 allocate( &
262 this%num_vertex_features(0:num_time_steps), &
263 source = num_vertex_features(1) &
264 )
265
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✓ Branch 9 taken 18 times.
✗ Branch 10 not taken.
18 elseif(size(num_vertex_features, 1) .eq. num_time_steps + 1)then
266 allocate( &
267 this%num_vertex_features(0:this%num_time_steps), &
268 source = num_vertex_features &
269
20/38
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✓ Branch 6 taken 18 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 18 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 18 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 18 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 18 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 18 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 18 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 18 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 18 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 18 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 18 times.
✓ Branch 43 taken 36 times.
✓ Branch 44 taken 18 times.
54 )
270 else
271 call stop_program( &
272 "Error: num_vertex_features must be a scalar or a vector of length &
273 &num_time_steps + 1" &
274 )
275 end if
276
13/24
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 18 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 18 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 18 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 18 times.
✓ Branch 29 taken 36 times.
✓ Branch 30 taken 18 times.
54 allocate( this%num_edge_features(0:this%num_time_steps), source = 0 )
277 18 this%use_graph_input = .true.
278
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
18 if(allocated(this%activation)) deallocate(this%activation)
279
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
18 if(.not.allocated(activation))then
280
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 this%activation = activation_setup("none")
281 else
282
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 17 times.
17 allocate(this%activation, source=activation)
283 end if
284
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
18 if(allocated(this%kernel_init)) deallocate(this%kernel_init)
285
2/2
✓ Branch 0 taken 14 times.
✓ Branch 1 taken 4 times.
18 if(.not.allocated(kernel_initialiser))then
286
1/2
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
14 buffer = get_default_initialiser(this%activation%name)
287
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 14 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 14 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 14 times.
✓ Branch 17 taken 14 times.
✗ Branch 18 not taken.
14 this%kernel_init = initialiser_setup(buffer)
288 else
289
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
4 allocate(this%kernel_init, source=kernel_initialiser)
290 end if
291
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 if(present(verbose))then
292
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
18 if(abs(verbose).gt.0)then
293 write(*,'("KIPF activation function: ",A)') &
294 trim(this%activation%name)
295 write(*,'("KIPF kernel initialiser: ",A)') &
296 trim(this%kernel_init%name)
297 end if
298 end if
299
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
18 if(allocated(this%num_params_msg)) deallocate(this%num_params_msg)
300
7/14
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
18 allocate(this%num_params_msg(1:this%num_time_steps))
301
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
36 do t = 1, this%num_time_steps
302 36 this%num_params_msg(t) = &
303
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
36 this%num_vertex_features(t-1) * this%num_vertex_features(t)
304 end do
305
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
18 if(allocated(this%input_shape)) deallocate(this%input_shape)
306
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
18 if(allocated(this%output_shape)) deallocate(this%output_shape)
307
308 18 end subroutine set_hyperparams_kipf
309 !###############################################################################
310
311
312 !###############################################################################
313
1/2
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
18 subroutine init_kipf(this, input_shape, verbose)
314 !! Initialise the message passing layer
315 use athena__initialiser, only: initialiser_setup
316 implicit none
317
318 ! Arguments
319 class(kipf_msgpass_layer_type), intent(inout) :: this
320 !! Instance of the fully connected layer
321 integer, dimension(:), intent(in) :: input_shape
322 !! Input shape
323 integer, optional, intent(in) :: verbose
324 !! Verbosity level
325
326 ! Local variables
327 integer :: t
328 !! Loop index
329 integer :: verbose_ = 0
330 !! Verbosity level
331
332
333 !---------------------------------------------------------------------------
334 ! Initialise optional arguments
335 !---------------------------------------------------------------------------
336
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
18 if(present(verbose)) verbose_ = verbose
337
338
339 !---------------------------------------------------------------------------
340 ! Initialise number of inputs
341 !---------------------------------------------------------------------------
342
4/8
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
18 if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
343
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 36 times.
✓ Branch 13 taken 18 times.
54 this%output_shape = [this%num_vertex_features(this%num_time_steps), 0]
344 18 this%num_params = this%get_num_params()
345
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
18 if(allocated(this%weight_shape)) deallocate(this%weight_shape)
346
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
18 if(allocated(this%bias_shape)) deallocate(this%bias_shape)
347
7/14
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
18 allocate(this%weight_shape(2,this%num_time_steps))
348
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
36 do t = 1, this%num_time_steps
349 126 this%weight_shape(:,t) = &
350
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 18 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 18 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 18 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 18 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 18 times.
✓ Branch 33 taken 36 times.
✓ Branch 34 taken 18 times.
72 [ this%num_vertex_features(t), this%num_vertex_features(t-1) ]
351 end do
352
353
354 !---------------------------------------------------------------------------
355 ! Allocate weight, weight steps (velocities), output, and activation
356 !---------------------------------------------------------------------------
357
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 17 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
18 if(allocated(this%params)) deallocate(this%params)
358
20/38
✓ Branch 0 taken 18 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 18 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 18 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 18 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 18 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 18 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 18 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 18 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 18 times.
✓ Branch 29 taken 18 times.
✓ Branch 30 taken 18 times.
✓ Branch 31 taken 18 times.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 18 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 18 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 18 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 18 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 18 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 18 times.
36 allocate(this%params(this%num_time_steps))
359
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
36 do t = 1, this%num_time_steps
360 36 call this%params(t)%allocate( &
361 108 array_shape = [ this%weight_shape(:,t), 1 ] &
362
14/24
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 18 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 18 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 18 times.
✓ Branch 27 taken 36 times.
✓ Branch 28 taken 18 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 18 times.
✓ Branch 31 taken 54 times.
✓ Branch 32 taken 18 times.
108 )
363
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 call this%params(t)%set_requires_grad(.true.)
364
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 this%params(t)%is_sample_dependent = .false.
365
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
18 this%params(t)%is_temporary = .false.
366
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
36 this%params(t)%fix_pointer = .true.
367 end do
368
369
370 !---------------------------------------------------------------------------
371 ! Initialise weights (kernels)
372 !---------------------------------------------------------------------------
373
2/2
✓ Branch 0 taken 18 times.
✓ Branch 1 taken 18 times.
36 do t = 1, this%num_time_steps
374 call this%kernel_init%initialise( &
375 180 this%params(t)%val(:,1), &
376 36 fan_in = this%num_vertex_features(t-1), &
377 36 fan_out = this%num_vertex_features(t), &
378 36 spacing = [ this%num_vertex_features(t) ] &
379
18/34
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 18 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 18 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 18 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 18 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 18 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 18 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 18 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 18 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 18 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 18 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 18 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 18 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 18 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 18 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 18 times.
✓ Branch 48 taken 18 times.
✓ Branch 49 taken 18 times.
54 )
380 end do
381
382
383 !---------------------------------------------------------------------------
384 ! Allocate arrays
385 !---------------------------------------------------------------------------
386
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 18 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
18 if(allocated(this%output)) deallocate(this%output)
387
388 18 end subroutine init_kipf
389 !###############################################################################
390
391
392 !##############################################################################!
393 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
394 !##############################################################################!
395
396
397 !###############################################################################
398 1 subroutine print_to_unit_kipf(this, unit)
399 !! Print kipf message passing layer to unit
400 use coreutils, only: to_upper
401 implicit none
402
403 ! Arguments
404 class(kipf_msgpass_layer_type), intent(in) :: this
405 !! Instance of the message passing layer
406 integer, intent(in) :: unit
407 !! File unit
408
409 ! Local variables
410 integer :: t
411 !! Loop index
412 character(100) :: fmt
413 !! Format string
414
415
416 ! Write initial parameters
417 !---------------------------------------------------------------------------
418 1 write(unit,'(3X,"NUM_TIME_STEPS = ",I0)') this%num_time_steps
419 write(fmt,'("(3X,""NUM_VERTEX_FEATURES ="",",I0,"(1X,I0))")') &
420 1 this%num_time_steps + 1
421 1 write(unit,fmt) this%num_vertex_features
422
423
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(this%activation%name .ne. 'none')then
424 call this%activation%print_to_unit(unit)
425 end if
426
427
428 ! Write learned parameters
429 !---------------------------------------------------------------------------
430 1 write(unit,'("WEIGHTS")')
431
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 do t = 1, this%num_time_steps, 1
432
14/24
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✓ Branch 31 taken 1 times.
✓ Branch 32 taken 1 times.
✓ Branch 33 taken 50 times.
✓ Branch 34 taken 1 times.
53 write(unit,'(5(E16.8E2))') this%params(t)%val
433 end do
434 1 write(unit,'("END WEIGHTS")')
435
436 1 end subroutine print_to_unit_kipf
437 !###############################################################################
438
439
440 !###############################################################################
441 1 subroutine read_kipf(this, unit, verbose)
442 !! Read the message passing layer
443 use athena__tools_infile, only: assign_val, assign_vec, get_val, move
444 use coreutils, only: to_lower, to_upper, icount
445 use athena__activation, only: read_activation
446 use athena__initialiser, only: initialiser_setup
447 implicit none
448
449 ! Arguments
450 class(kipf_msgpass_layer_type), intent(inout) :: this
451 !! Instance of the message passing layer
452 integer, intent(in) :: unit
453 !! Unit to read from
454 integer, optional, intent(in) :: verbose
455 !! Verbosity level
456
457 ! Local variables
458 integer :: stat
459 !! Status of read
460 integer :: verbose_ = 0
461 !! Verbosity level
462 integer :: t, j, k, c, itmp1, iline
463 !! Loop variables and temporary integer
464 integer :: num_time_steps = 0
465 !! Number of time steps
466 character(14) :: kernel_initialiser_name=''
467 !! Initialisers
468 character(20) :: activation_name=''
469 !! Activation function name
470 2 class(base_actv_type), allocatable :: activation
471 !! Activation function
472 3 class(base_init_type), allocatable :: kernel_initialiser
473 !! Initialisers
474 1 integer, dimension(:), allocatable :: num_vertex_features
475 !! Number of vertex and edge features
476 character(256) :: buffer, tag, err_msg
477 !! Buffer, tag, and error message
478 1 real(real32), allocatable, dimension(:) :: data_list
479 !! Data list
480 integer :: param_line, final_line
481 !! Parameter line number
482
483
484 ! Initialise optional arguments
485 !---------------------------------------------------------------------------
486
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(verbose)) verbose_ = verbose
487
488
489 ! Loop over tags in layer card
490 !---------------------------------------------------------------------------
491 1 iline = 0
492 1 param_line = 0
493 1 final_line = 0
494 14 tag_loop: do
495
496 ! Check for end of file
497 !------------------------------------------------------------------------
498 15 read(unit,'(A)',iostat=stat) buffer
499
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
15 if(stat.ne.0)then
500 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
501 to_upper(this%name)
502 call stop_program(err_msg)
503 return
504 end if
505
2/4
✓ Branch 2 taken 15 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 15 times.
15 if(trim(adjustl(buffer)).eq."") cycle tag_loop
506
507 ! Check for end of layer card
508 !------------------------------------------------------------------------
509
4/6
✓ Branch 3 taken 15 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 15 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 14 times.
30 if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
510 1 final_line = iline
511 1 backspace(unit)
512 15 exit tag_loop
513 end if
514 14 iline = iline + 1
515
516
2/4
✓ Branch 2 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 5 not taken.
14 tag=trim(adjustl(buffer))
517
6/10
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
14 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
518
519 ! Read parameters from file
520 !------------------------------------------------------------------------
521 28 select case(trim(tag))
522 case("NUM_TIME_STEPS")
523 2 call assign_val(buffer, num_time_steps, itmp1)
524 case("NUM_VERTEX_FEATURES")
525 1 itmp1 = icount(get_val(buffer))
526
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(num_vertex_features(itmp1), source=0)
527 1 call assign_vec(buffer, num_vertex_features, itmp1)
528 case("ACTIVATION")
529 iline = iline - 1
530 backspace(unit)
531 activation = read_activation(unit, iline)
532 case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALisER")
533 call assign_val(buffer, kernel_initialiser_name, itmp1)
534 case("WEIGHTS")
535 1 kernel_initialiser_name = 'zeros'
536 1 param_line = iline
537 case default
538 ! Don't look for "e" due to scientific notation of numbers
539 ! ... i.e. exponent (E+00)
540
3/4
✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✓ Branch 5 taken 1 times.
22 if(scan(to_lower(trim(adjustl(buffer))),&
541 'abcdfghijklmnopqrstuvwxyz').eq.0)then
542 11 cycle tag_loop
543
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 elseif(tag(:3).eq.'END')then
544 11 cycle tag_loop
545 end if
546 write(err_msg,'("Unrecognised line in input file: ",A)') &
547 trim(adjustl(buffer))
548 call stop_program(err_msg)
549
5/8
✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 11 times.
28 return
550 end select
551 end do tag_loop
552
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 kernel_initialiser = initialiser_setup(kernel_initialiser_name)
553
554
555 ! Set hyperparameters and initialise layer
556 !---------------------------------------------------------------------------
557
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(num_time_steps.gt.0 .and. num_time_steps.ne.size(num_vertex_features,1)-1)then
558 write(err_msg,'("NUM_TIME_STEPS = ",I0," does not match length of "// &
559 &"NUM_VERTEX_FEATURES = ",I0)') num_time_steps, &
560 size(num_vertex_features,1)-1
561 call stop_program(err_msg)
562 return
563 end if
564 call this%set_hyperparams( &
565 num_time_steps = num_time_steps, &
566 num_vertex_features = num_vertex_features, &
567 activation = activation, &
568 kernel_initialiser = kernel_initialiser, &
569 verbose = verbose_ &
570 1 )
571
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 6 taken 2 times.
✓ Branch 7 taken 1 times.
3 call this%init(input_shape=[this%num_vertex_features(0), 0])
572
573
574 ! Check if WEIGHTS card was found
575 !---------------------------------------------------------------------------
576
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(param_line.eq.0)then
577 write(0,*) "WARNING: WEIGHTS card in "//to_upper(trim(this%name))//" not found"
578 else
579 1 call move(unit, param_line - iline, iostat=stat)
580
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 do t = 1, this%num_time_steps
581
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
1 allocate(data_list(this%num_params_msg(t)), source=0._real32)
582 1 c = 1
583 1 k = 1
584
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 10 times.
11 data_concat_loop: do while(c.le.this%num_params_msg(t))
585 10 read(unit,'(A)',iostat=stat) buffer
586
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 if(stat.ne.0) exit data_concat_loop
587 10 k = icount(buffer)
588
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 60 times.
✓ Branch 3 taken 50 times.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 50 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 50 times.
60 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
589 10 c = c + k
590 end do data_concat_loop
591
17/32
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✓ Branch 45 taken 50 times.
✓ Branch 46 taken 1 times.
51 this%params(t)%val(:,1) = data_list(1:this%num_params_msg(t))
592
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
2 deallocate(data_list)
593 end do
594
595 ! Check for end of weights card
596 !------------------------------------------------------------------------
597 1 read(unit,'(A)') buffer
598
2/4
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
1 if(trim(adjustl(buffer)).ne."END WEIGHTS")then
599 write(0,*) trim(adjustl(buffer))
600 call stop_program("END WEIGHTS not where expected")
601 return
602 end if
603 end if
604
605
606 !---------------------------------------------------------------------------
607 ! Check for end of layer card
608 !---------------------------------------------------------------------------
609 1 read(unit,'(A)') buffer
610
3/6
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
2 if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
611 write(0,*) trim(adjustl(buffer))
612 write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
613 call stop_program(err_msg)
614 1 return
615 end if
616
617
5/12
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
1 end subroutine read_kipf
618 !###############################################################################
619
620
621 !###############################################################################
622 1 function read_kipf_msgpass_layer(unit, verbose) result(layer)
623 !! Read kipf message passing layer from file and return layer
624 implicit none
625
626 ! Arguments
627 integer, intent(in) :: unit
628 !! Unit number
629 integer, optional, intent(in) :: verbose
630 !! Verbosity level
631 class(base_layer_type), allocatable :: layer
632 !! Instance of the message passing layer
633
634 ! Local variables
635 integer :: verbose_ = 0
636 !! Verbosity level
637
638
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(present(verbose)) verbose_ = verbose
639 allocate(layer, source = kipf_msgpass_layer_type( &
640 num_time_steps = 1, &
641 num_vertex_features = [ 0, 0 ] &
642
23/84
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✓ Branch 47 taken 1 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 1 times.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✓ Branch 59 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 62 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✓ Branch 69 taken 1 times.
✗ Branch 70 not taken.
✗ Branch 71 not taken.
✓ Branch 72 taken 1 times.
✗ Branch 73 not taken.
✗ Branch 74 not taken.
✗ Branch 76 not taken.
✓ Branch 77 taken 1 times.
✓ Branch 78 taken 1 times.
✗ Branch 79 not taken.
✗ Branch 80 not taken.
✓ Branch 81 taken 1 times.
✓ Branch 83 taken 1 times.
✗ Branch 84 not taken.
✓ Branch 85 taken 1 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 1 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 1 times.
✗ Branch 90 not taken.
2 ))
643 1 call layer%read(unit, verbose=verbose_)
644
645 2 end function read_kipf_msgpass_layer
646 !###############################################################################
647
648
649 !###############################################################################
650 3 function get_attributes_kipf(this) result(attributes)
651 !! Get the attributes of the Kipf GCN layer (for ONNX export)
652 implicit none
653
654 ! Arguments
655 class(kipf_msgpass_layer_type), intent(in) :: this
656 !! Instance of the message passing layer
657 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
658 !! Attributes for ONNX export
659
660 ! Local variables
661 integer :: t
662 !! Loop index
663 character(256) :: buffer
664 !! Buffer for converting attributes to strings
665
666
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✓ Branch 21 taken 9 times.
✓ Branch 22 taken 3 times.
✓ Branch 23 taken 9 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 9 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 9 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 9 times.
12 allocate(attributes(3))
667
668 3 write(buffer, '(I0)') this%num_time_steps
669 attributes(1) = onnx_attribute_type( &
670
6/12
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 name='num_time_steps', type='int', val=trim(buffer))
671
672 3 buffer = ''
673
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 3 times.
9 do t = 0, this%num_time_steps
674
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
9 if(t .eq. 0)then
675
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
3 write(buffer, '(I0)') this%num_vertex_features(t)
676 else
677
3/6
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
3 write(buffer, '(A," ",I0)') trim(buffer), this%num_vertex_features(t)
678 end if
679 end do
680 attributes(2) = onnx_attribute_type( &
681
6/12
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 name='num_vertex_features', type='ints', val=trim(buffer))
682
683 attributes(3) = onnx_attribute_type( &
684 name='message_activation', type='string', &
685
6/12
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 val=trim(this%activation%name))
686
687 3 end function get_attributes_kipf
688 !###############################################################################
689
690
691 !###############################################################################
692 3 subroutine emit_onnx_nodes_kipf( &
693 this, prefix, &
694 3 nodes, num_nodes, max_nodes, &
695
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 inits, num_inits, max_inits, &
696 input_name, is_last_layer, format &
697 )
698 !! Emit ONNX JSON nodes for Kipf GCN layer
699 !!
700 !! Decomposes the Kipf message passing layer into standard ONNX ops:
701 !! Gather, ScatterElements, Mul, Pow, MatMul, activation
702 !!
703 !! Kipf GCN: H^(l+1) = sigma(D~^(-1/2) A~ D~^(-1/2) H^(l) W^(l))
704 !! Decomposed per timestep:
705 !! 1. Extract source/target indices from edge_index
706 !! 2. Gather source vertex features
707 !! 3. Compute normalisation coeff = (deg_src * deg_tgt)^(-0.5)
708 !! 4. Scale source features by coefficient
709 !! 5. Scatter-add to target vertices
710 !! 6. MatMul with weight W (transposed)
711 !! 7. Apply activation
712 use athena__onnx_msgpass_utils, only: emit_output_identity
713 implicit none
714
715 ! Arguments
716 class(kipf_msgpass_layer_type), intent(in) :: this
717 !! Instance of the layer
718 character(*), intent(in) :: prefix
719 !! Node name prefix (e.g. "node_2")
720 type(onnx_node_type), intent(inout), dimension(:) :: nodes
721 !! Accumulator for ONNX nodes
722 integer, intent(inout) :: num_nodes
723 !! Current number of nodes
724 integer, intent(in) :: max_nodes
725 !! Maximum capacity
726 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
727 !! Accumulator for ONNX initialisers
728 integer, intent(inout) :: num_inits
729 !! Current number of initialisers
730 integer, intent(in) :: max_inits
731 !! Maximum capacity
732 character(*), optional, intent(in) :: input_name
733 !! Unused sequential input name
734 logical, optional, intent(in) :: is_last_layer
735 !! Unused last-layer flag
736 integer, optional, intent(in) :: format
737 !! Unused export format selector
738
739 ! Local variables
740 integer :: t
741 !! Time-step index
742 character(128) :: cur_vertex_name
743 !! Current timestep output tensor name
744
745
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 3 times.
6 do t = 1, this%num_time_steps
746 call emit_kipf_timestep( &
747 prefix, t, &
748 6 this%num_vertex_features(t-1), &
749 6 this%num_vertex_features(t), &
750 30 this%params(t)%val(:,1), &
751 this%activation%name, &
752 nodes, num_nodes, max_nodes, &
753 inits, num_inits, max_inits, &
754 cur_vertex_name &
755
20/40
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 3 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 3 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 3 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 3 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 3 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 3 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 3 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 3 times.
6 )
756 end do
757
758 ! Kipf produces node-level output (no readout).
759 call emit_output_identity( &
760 prefix, trim(cur_vertex_name), this%activation%name, &
761
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 nodes, num_nodes)
762
763
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 end subroutine emit_onnx_nodes_kipf
764 !###############################################################################
765
766
767 !###############################################################################
768 3 subroutine emit_kipf_timestep( &
769 3 prefix, t, nv_in, nv_out, weight_data, activation_name, &
770 3 nodes, num_nodes, max_nodes, &
771
3/6
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
6 inits, num_inits, max_inits, vertex_out)
772 !! Emit ONNX nodes for one Kipf GCN time step.
773 use athena__onnx_utils, only: emit_node, emit_constant_int64, &
774 emit_constant_float, emit_activation_node
775 use athena__onnx_msgpass_utils, only: get_timestep_output_name, &
776 emit_edge_index_component, emit_scatter_aggregator, &
777 emit_weight_initialiser_2d
778 implicit none
779
780 ! Arguments
781 character(*), intent(in) :: prefix
782 integer, intent(in) :: t, nv_in, nv_out
783 real(real32), intent(in) :: weight_data(:)
784 character(*), intent(in) :: activation_name
785 type(onnx_node_type), intent(inout), dimension(:) :: nodes
786 integer, intent(inout) :: num_nodes
787 integer, intent(in) :: max_nodes
788 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
789 integer, intent(inout) :: num_inits
790 integer, intent(in) :: max_inits
791 character(128), intent(out) :: vertex_out
792
793 ! Local variables
794 character(128) :: tp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7
795 character(128) :: vertex_in, edge_index_in, degree_in
796 character(128) :: src_idx, target_idx, aggr_name
797 character(len=*), parameter :: onnx_axis0_attr = &
798 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
799 character(len=*), parameter :: onnx_transpose_10_attr = &
800 ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
801 '"type": "INTS"}]'
802 character(len=*), parameter :: onnx_cast_float_attr = &
803 ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
804
805
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tp, '(A,"_t",I0)') trim(prefix), t
806
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(vertex_in, '(A,"_vertex_in")') trim(prefix)
807
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(edge_index_in, '(A,"_edge_index_in")') trim(prefix)
808
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(degree_in, '(A,"_degree_in")') trim(prefix)
809
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 if(t .gt. 1)then
810 call get_timestep_output_name( &
811 prefix, t-1, activation_name, '_mm_out', '', vertex_in)
812 end if
813
814 ! --- Step 1: Extract source and target indices from edge_index ---
815
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp1, '(A,"_idx0")') trim(tp)
816 call emit_constant_int64(trim(tmp1), [0], [1], &
817
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 nodes, num_nodes, inits, num_inits)
818
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp2, '(A,"_idx2")') trim(tp)
819 call emit_constant_int64(trim(tmp2), [2], [1], &
820
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 nodes, num_nodes, inits, num_inits)
821
822 call emit_edge_index_component( &
823
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 tp, edge_index_in, trim(tmp1), 'src', src_idx, nodes, num_nodes)
824 call emit_edge_index_component( &
825
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
3 tp, edge_index_in, trim(tmp2), 'tgt', target_idx, nodes, num_nodes)
826
827 ! --- Step 2: Gather source features and compute normalisation ---
828
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp1, '(A,"_src_feat")') trim(tp)
829 call emit_node('Gather', trim(tp)//'_gather_vfeat', &
830 trim(tmp1), onnx_axis0_attr, nodes, num_nodes, &
831
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(vertex_in), in2=trim(src_idx))
832
833
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp2, '(A,"_deg_f")') trim(tp)
834 call emit_node('Cast', trim(tp)//'_cast_deg', &
835 trim(tmp2), onnx_cast_float_attr, nodes, num_nodes, &
836
6/12
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 13 taken 3 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✗ Branch 16 not taken.
3 in1=trim(degree_in))
837
838
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp4, '(A,"_deg_src")') trim(tp)
839 call emit_node('Gather', trim(tp)//'_gather_deg_src', &
840 trim(tmp4), onnx_axis0_attr, nodes, num_nodes, &
841
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp2), in2=trim(src_idx))
842
843
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp6, '(A,"_deg_tgt")') trim(tp)
844 call emit_node('Gather', trim(tp)//'_gather_deg_tgt', &
845 trim(tmp6), onnx_axis0_attr, nodes, num_nodes, &
846
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp2), in2=trim(target_idx))
847
848
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp7, '(A,"_deg_prod")') trim(tp)
849 call emit_node('Mul', trim(tp)//'_mul_deg', &
850 trim(tmp7), '', nodes, num_nodes, &
851
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp4), in2=trim(tmp6))
852
853
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp2, '(A,"_neg_half")') trim(tp)
854 call emit_constant_float(trim(tmp2), [-0.5_real32], [1], &
855
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 nodes, num_nodes, inits, num_inits)
856
857
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp3, '(A,"_coeff")') trim(tp)
858 call emit_node('Pow', trim(tp)//'_pow_coeff', &
859 trim(tmp3), '', nodes, num_nodes, &
860
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp7), in2=trim(tmp2))
861
862 ! Unsqueeze coeff for broadcasting and scale the source features.
863
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp4, '(A,"_coeff_us")') trim(tp)
864
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp6, '(A,"_us_ax1")') trim(tp)
865 call emit_constant_int64(trim(tmp6), [1], [1], &
866
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 nodes, num_nodes, inits, num_inits)
867 call emit_node('Unsqueeze', trim(tp)//'_us_coeff', &
868 trim(tmp4), '', nodes, num_nodes, &
869
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp3), in2=trim(tmp6))
870
871
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp2, '(A,"_scaled_feat")') trim(tp)
872 call emit_node('Mul', trim(tp)//'_mul_coeff', &
873 trim(tmp2), '', nodes, num_nodes, &
874
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(tmp1), in2=trim(tmp4))
875
876 ! --- Step 3: Scatter-add normalised messages to target vertices ---
877 call emit_scatter_aggregator( &
878 tp, vertex_in, target_idx, trim(tmp2), nv_in, &
879
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 nodes, num_nodes, inits, num_inits, aggr_name)
880
881 ! --- Step 4: MatMul with weight W ---
882
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp1, '(A,"_W")') trim(tp)
883 call emit_weight_initialiser_2d( &
884
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✓ Branch 20 taken 3 times.
✗ Branch 21 not taken.
3 trim(tmp1), nv_out, nv_in, weight_data, inits, num_inits)
885
886
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp2, '(A,"_Wt")') trim(tp)
887 call emit_node('Transpose', trim(tp)//'_transpose_W', &
888 trim(tmp2), onnx_transpose_10_attr, nodes, num_nodes, &
889
6/12
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 13 taken 3 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✗ Branch 16 not taken.
3 in1=trim(tmp1))
890
891
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(tmp3, '(A,"_mm_out")') trim(tp)
892 call emit_node('MatMul', trim(tp)//'_matmul', &
893 trim(tmp3), '', nodes, num_nodes, &
894
7/14
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
3 in1=trim(aggr_name), in2=trim(tmp2))
895
896 ! --- Step 5: Activation ---
897
2/4
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 if(trim(activation_name) .ne. 'none')then
898 call emit_activation_node(activation_name, trim(tp), trim(tmp3), &
899
5/10
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
3 nodes, num_nodes, max_nodes)
900
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✓ Branch 13 taken 3 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✗ Branch 16 not taken.
3 vertex_out = trim(nodes(num_nodes)%outputs(1))
901 else
902 vertex_out = trim(tmp3)
903 end if
904
905
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 end subroutine emit_kipf_timestep
906 !###############################################################################
907
908
909 !##############################################################################!
910 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
911 !##############################################################################!
912
913
914 !##############################################################################!
915
1/2
✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
16 subroutine update_message_kipf(this, input)
916 !! Update the message
917 implicit none
918
919 ! Arguments
920 class(kipf_msgpass_layer_type), intent(inout), target :: this
921 !! Instance of the message passing layer
922 class(array_type), dimension(:,:), intent(in), target :: input
923 !! Input to the message passing layer
924
925 ! Local variables
926 integer :: s, t
927 !! Batch index, time step
928 type(array_type), pointer :: ptr1, ptr2, ptr3
929 !! Pointers to arrays
930
931
2/2
✓ Branch 0 taken 5 times.
✓ Branch 1 taken 11 times.
16 if(allocated(this%output))then
932
7/14
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
5 if(size(this%output,2).ne.size(input,2))then
933 deallocate(this%output)
934 allocate(this%output(1,size(input,2)))
935 end if
936 else
937
25/46
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 11 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 11 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 11 times.
✓ Branch 18 taken 11 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 11 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 11 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 11 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 11 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 11 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 11 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 11 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 11 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 11 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 11 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 11 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 11 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 11 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 11 times.
✓ Branch 59 taken 11 times.
✓ Branch 60 taken 11 times.
✓ Branch 61 taken 11 times.
✓ Branch 62 taken 11 times.
33 allocate(this%output(1,size(input,2)))
938 end if
939
940
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✓ Branch 18 taken 16 times.
✓ Branch 19 taken 16 times.
32 do s = 1, size(input,2)
941
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
16 ptr1 => input(1,s)
942
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 16 times.
32 do t = 1, this%num_time_steps
943 ptr2 => kipf_propagate( &
944 ptr1, &
945 64 this%graph(s)%adj_ia, this%graph(s)%adj_ja &
946
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 )
947
948 ! this%z(t,s) = kipf_update( &
949 ! this%message(t,s), this%params(t), this%graph(s)%adj_ia &
950 ! )
951
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 ptr3 => matmul( this%params(t), ptr2 )
952 32 ptr1 => this%activation%apply( ptr3 )
953 end do
954
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 call this%output(1,s)%zero_grad()
955
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
16 call this%output(1,s)%assign_and_deallocate_source(ptr1)
956
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
32 this%output(1,s)%is_temporary = .false.
957 end do
958
959 16 end subroutine update_message_kipf
960 !###############################################################################
961
962
963 !###############################################################################
964 16 subroutine update_readout_kipf(this)
965 !! Update the readout (empty for node-level output)
966 implicit none
967 ! Arguments
968 class(kipf_msgpass_layer_type), intent(inout), target :: this
969 !! Instance of the message passing layer
970 16 end subroutine update_readout_kipf
971 !###############################################################################
972
973
72/167
✓ Branch 0 taken 60 times.
✓ Branch 1 taken 51 times.
✓ Branch 2 taken 60 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 51 times.
✓ Branch 5 taken 60 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✓ Branch 37 taken 60 times.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 51 times.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✓ Branch 43 taken 51 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 51 times.
✓ Branch 46 taken 51 times.
✓ Branch 47 taken 60 times.
✓ Branch 48 taken 51 times.
✓ Branch 49 taken 60 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 60 times.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✓ Branch 54 taken 60 times.
✓ Branch 55 taken 60 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 60 times.
✓ Branch 58 taken 60 times.
✓ Branch 59 taken 60 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 60 times.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✓ Branch 64 taken 60 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 60 times.
✓ Branch 67 taken 51 times.
✓ Branch 68 taken 60 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 111 times.
✓ Branch 71 taken 111 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 3 times.
✓ Branch 74 taken 48 times.
✓ Branch 75 taken 63 times.
✓ Branch 76 taken 3 times.
✓ Branch 77 taken 60 times.
✓ Branch 78 taken 3 times.
✗ Branch 79 not taken.
✓ Branch 80 taken 63 times.
✗ Branch 81 not taken.
✓ Branch 82 taken 63 times.
✓ Branch 83 taken 3 times.
✓ Branch 84 taken 60 times.
✓ Branch 85 taken 3 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 3 times.
✗ Branch 88 not taken.
✗ Branch 89 not taken.
✓ Branch 90 taken 3 times.
✗ Branch 91 not taken.
✓ Branch 92 taken 3 times.
✗ Branch 93 not taken.
✓ Branch 94 taken 3 times.
✗ Branch 95 not taken.
✗ Branch 96 not taken.
✗ Branch 97 not taken.
✗ Branch 98 not taken.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✗ Branch 101 not taken.
✓ Branch 102 taken 3 times.
✗ Branch 103 not taken.
✗ Branch 104 not taken.
✗ Branch 105 not taken.
✗ Branch 106 not taken.
✗ Branch 107 not taken.
✗ Branch 108 not taken.
✓ Branch 109 taken 48 times.
✓ Branch 110 taken 3 times.
✓ Branch 111 taken 51 times.
✗ Branch 112 not taken.
✓ Branch 113 taken 48 times.
✓ Branch 114 taken 3 times.
✓ Branch 115 taken 3 times.
✗ Branch 116 not taken.
✓ Branch 118 taken 51 times.
✗ Branch 119 not taken.
✗ Branch 120 not taken.
✓ Branch 121 taken 51 times.
✓ Branch 122 taken 51 times.
✗ Branch 123 not taken.
✗ Branch 124 not taken.
✓ Branch 125 taken 51 times.
✓ Branch 126 taken 51 times.
✗ Branch 127 not taken.
✗ Branch 128 not taken.
✓ Branch 129 taken 51 times.
✓ Branch 130 taken 51 times.
✗ Branch 131 not taken.
✓ Branch 132 taken 51 times.
✗ Branch 133 not taken.
✓ Branch 134 taken 51 times.
✗ Branch 135 not taken.
✗ Branch 136 not taken.
✓ Branch 137 taken 51 times.
✓ Branch 139 taken 51 times.
✗ Branch 140 not taken.
✗ Branch 141 not taken.
✓ Branch 142 taken 51 times.
✗ Branch 143 not taken.
✓ Branch 144 taken 51 times.
✓ Branch 146 taken 51 times.
✗ Branch 147 not taken.
✓ Branch 148 taken 51 times.
✗ Branch 149 not taken.
✗ Branch 150 not taken.
✗ Branch 151 not taken.
✓ Branch 153 taken 51 times.
✗ Branch 154 not taken.
✗ Branch 155 not taken.
✓ Branch 156 taken 51 times.
✗ Branch 157 not taken.
✓ Branch 158 taken 51 times.
✓ Branch 160 taken 51 times.
✗ Branch 161 not taken.
✗ Branch 162 not taken.
✓ Branch 163 taken 51 times.
✓ Branch 164 taken 51 times.
✗ Branch 165 not taken.
✗ Branch 166 not taken.
✓ Branch 167 taken 51 times.
✓ Branch 168 taken 51 times.
✗ Branch 169 not taken.
✗ Branch 170 not taken.
✓ Branch 171 taken 51 times.
570 end module athena__kipf_msgpass_layer
974