GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_onnx_msgpass_utils.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 71 111 64.0%
Functions: 0 0 -%
Branches: 231 616 37.5%

Line Branch Exec Source
1 module athena__onnx_msgpass_utils
2 !! Shared ONNX builder helpers for message-passing layers.
3 !!
4 !! This module factors out the repeated edge-index extraction,
5 !! scatter accumulation, weight export, and output naming logic used by
6 !! the Duvenaud and Kipf message-passing layers.
7 use coreutils, only: real32
8 use athena__misc_types, only: onnx_node_type, onnx_initialiser_type, &
9 onnx_tensor_type
10 use athena__onnx_utils, only: emit_node, emit_squeeze_node, &
11 emit_constant_int64, emit_constant_float, &
12 emit_constant_of_shape_float, emit_activation_node, &
13 col_to_row_major_2d
14 implicit none
15
16 private
17
18 public :: emit_msgpass_graph_inputs
19 public :: emit_output_identity
20 public :: get_timestep_output_name
21 public :: emit_edge_index_component
22 public :: emit_scatter_aggregator
23 public :: emit_weight_initialiser_2d
24 public :: emit_weight_initialiser_3d
25
26 character(len=*), parameter :: onnx_axis0_attr = &
27 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
28 character(len=*), parameter :: onnx_concat_axis0_attr = &
29 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
30 character(len=*), parameter :: onnx_concat_axis1_attr = &
31 ' "attribute": [{"name": "axis", "i": "1", "type": "INT"}]'
32 character(len=*), parameter :: onnx_softmax_axis0_attr = &
33 ' "attribute": [{"name": "axis", "i": "0", "type": "INT"}]'
34 character(len=*), parameter :: onnx_transpose_10_attr = &
35 ' "attribute": [{"name": "perm", "ints": ["1", "0"], ' // &
36 '"type": "INTS"}]'
37 character(len=*), parameter :: onnx_reduce_sum_attr = &
38 ' "attribute": [{"name": "keepdims", "i": "0", ' // &
39 '"type": "INT"}]'
40 character(len=*), parameter :: onnx_cast_float_attr = &
41 ' "attribute": [{"name": "to", "i": "1", "type": "INT"}]'
42 character(len=*), parameter :: onnx_cast_int64_attr = &
43 ' "attribute": [{"name": "to", "i": "7", "type": "INT"}]'
44 character(len=*), parameter :: onnx_scatter_add_attr = &
45 ' "attribute": [' // &
46 '{"name": "axis", "i": "0", "type": "INT"}, ' // &
47 '{"name": "reduction", "s": "YWRk", "type": "STRING"}]'
48
49 contains
50
51
52 !###############################################################################
53 subroutine emit_msgpass_graph_inputs(prefix, input_shape, graph_inputs, &
54 num_inputs)
55 !! Emit the standard graph input tensors used by message-passing layers.
56 !!
57 !! Adds vertex features, optional edge features, edge_index, and degree.
58 implicit none
59
60 ! Arguments
61 character(*), intent(in) :: prefix
62 !! Input name prefix (e.g. "input_1")
63 integer, dimension(:), intent(in) :: input_shape
64 !! Layer input shape [num_vertex_features, num_edge_features]
65 type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
66 !! Accumulator for graph inputs
67 integer, intent(inout) :: num_inputs
68 !! Current number of graph inputs
69
70 ! Vertex features: [num_nodes, nv]
71 call add_graph_input_tensor( &
72 graph_inputs, num_inputs, trim(prefix)//'_vertex', 1, &
73 -1, 'num_nodes', input_shape(1), '')
74
75 ! Edge features: [num_edges, ne]
76 if(input_shape(2) .gt. 0)then
77 call add_graph_input_tensor( &
78 graph_inputs, num_inputs, trim(prefix)//'_edge', 1, &
79 -1, 'num_edges', input_shape(2), '')
80 end if
81
82 ! Edge index: [3, num_csr_entries]
83 call add_graph_input_tensor( &
84 graph_inputs, num_inputs, trim(prefix)//'_edge_index', 7, &
85 3, '', -1, 'num_csr_entries')
86
87 ! Node degree: [num_nodes]
88 call add_graph_input_tensor( &
89 graph_inputs, num_inputs, trim(prefix)//'_degree', 7, &
90 -1, 'num_nodes')
91
92 end subroutine emit_msgpass_graph_inputs
93 !###############################################################################
94
95
96 !###############################################################################
97 4 subroutine emit_output_identity(prefix, source_name, activation_name, &
98 4 nodes, num_nodes)
99 !! Emit a final Identity node using the standard ATHENA output naming.
100 implicit none
101
102 ! Arguments
103 character(*), intent(in) :: prefix
104 !! Layer node prefix
105 character(*), intent(in) :: source_name
106 !! Source tensor to rename
107 character(*), intent(in) :: activation_name
108 !! Final activation name used in the exported output suffix
109 type(onnx_node_type), intent(inout), dimension(:) :: nodes
110 !! Accumulator for ONNX nodes
111 integer, intent(inout) :: num_nodes
112 !! Current number of nodes
113
114 ! Local variables
115 4 character(:), allocatable :: suffix
116
117
3/8
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
4 suffix = '_output'
118
2/4
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
4 if(trim(activation_name) .ne. 'none')then
119
5/10
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✓ Branch 8 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
4 suffix = '_' // trim(adjustl(activation_name)) // '_output'
120 end if
121 call emit_node('Identity', trim(prefix)//'_identity', &
122 trim(prefix)//trim(suffix), '', nodes, num_nodes, &
123
7/14
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✓ Branch 19 taken 4 times.
✗ Branch 20 not taken.
4 in1=trim(source_name))
124
125
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
8 end subroutine emit_output_identity
126 !###############################################################################
127
128
129 !###############################################################################
130 subroutine add_graph_input_tensor( &
131 graph_inputs, num_inputs, name, elem_type, &
132 dim1, dim_param1, dim2, dim_param2)
133 !! Add one graph input tensor declaration to the ONNX input list.
134 implicit none
135
136 ! Arguments
137 type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
138 integer, intent(inout) :: num_inputs
139 character(*), intent(in) :: name
140 integer, intent(in) :: elem_type, dim1
141 character(*), intent(in) :: dim_param1
142 integer, optional, intent(in) :: dim2
143 character(*), optional, intent(in) :: dim_param2
144
145 num_inputs = num_inputs + 1
146 graph_inputs(num_inputs)%name = trim(name)
147 graph_inputs(num_inputs)%elem_type = elem_type
148 if(present(dim2))then
149 allocate(graph_inputs(num_inputs)%dims(2))
150 allocate(graph_inputs(num_inputs)%dim_params(2))
151 graph_inputs(num_inputs)%dims = [ dim1, dim2 ]
152 graph_inputs(num_inputs)%dim_params(1) = dim_param1
153 graph_inputs(num_inputs)%dim_params(2) = dim_param2
154 else
155 allocate(graph_inputs(num_inputs)%dims(1))
156 allocate(graph_inputs(num_inputs)%dim_params(1))
157 graph_inputs(num_inputs)%dims(1) = dim1
158 graph_inputs(num_inputs)%dim_params(1) = dim_param1
159 end if
160
161 end subroutine add_graph_input_tensor
162 !###############################################################################
163
164
165 !###############################################################################
166 3 subroutine get_timestep_output_name( &
167 prefix, t, activation_name, inactive_suffix, activation_suffix, output)
168 !! Build the canonical ONNX output name for one exported timestep.
169 implicit none
170
171 ! Arguments
172 character(*), intent(in) :: prefix
173 integer, intent(in) :: t
174 character(*), intent(in) :: activation_name
175 character(*), intent(in) :: inactive_suffix, activation_suffix
176 character(128), intent(out) :: output
177
178 ! Local variables
179 character(128) :: step_prefix
180
181
1/2
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 write(step_prefix, '(A,"_t",I0)') trim(prefix), t
182
2/4
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 if(trim(activation_name) .ne. 'none')then
183
2/4
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 output = trim(step_prefix)
184
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 if(len_trim(activation_suffix) .gt. 0)then
185
3/6
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 8 not taken.
3 output = trim(output) // trim(activation_suffix)
186 end if
187 output = trim(output) // '_' // trim(adjustl(activation_name)) // &
188
3/6
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
3 '_output'
189 else
190 output = trim(step_prefix) // trim(inactive_suffix)
191 end if
192
193
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 end subroutine get_timestep_output_name
194 !###############################################################################
195
196
197 !###############################################################################
198 12 subroutine emit_edge_index_component( &
199 tp, edge_index_in, index_name, tag, component_out, &
200
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
12 nodes, num_nodes)
201 !! Gather one edge_index row and squeeze it into a vector.
202 implicit none
203
204 ! Arguments
205 character(*), intent(in) :: tp, edge_index_in, index_name, tag
206 character(128), intent(out) :: component_out
207 type(onnx_node_type), intent(inout), dimension(:) :: nodes
208 integer, intent(inout) :: num_nodes
209
210 ! Local variables
211 character(128) :: raw_name
212
213
2/4
✓ Branch 3 taken 12 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
12 write(raw_name, '(A,"_",A,"_raw")') trim(tp), trim(tag)
214
2/4
✓ Branch 3 taken 12 times.
✗ Branch 4 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
12 write(component_out, '(A,"_",A)') trim(tp), trim(tag)
215 call emit_node('Gather', trim(tp)//'_gather_'//trim(tag), &
216 trim(raw_name), onnx_axis0_attr, nodes, num_nodes, &
217
8/16
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 12 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 12 times.
✓ Branch 18 taken 12 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 12 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 23 not taken.
12 in1=trim(edge_index_in), in2=trim(index_name))
218 call emit_squeeze_node(trim(tp)//'_sq_'//trim(tag), &
219 trim(raw_name), trim(tp)//'_idx0', trim(component_out), &
220
8/16
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✓ Branch 11 taken 12 times.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 12 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 12 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 12 times.
✓ Branch 24 taken 12 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 12 times.
✗ Branch 27 not taken.
12 nodes, num_nodes)
221
222
1/2
✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
12 end subroutine emit_edge_index_component
223 !###############################################################################
224
225
226 !###############################################################################
227 5 subroutine emit_scatter_aggregator( &
228 tp, vertex_in, target_in, message_in, feature_dim, &
229
2/4
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
5 nodes, num_nodes, inits, num_inits, aggr_out)
230 !! Emit the zero-initialise, expand, and scatter-add aggregation block.
231 implicit none
232
233 ! Arguments
234 character(*), intent(in) :: tp, vertex_in, target_in, message_in
235 integer, intent(in) :: feature_dim
236 type(onnx_node_type), intent(inout), dimension(:) :: nodes
237 integer, intent(inout) :: num_nodes
238 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
239 integer, intent(inout) :: num_inits
240 character(128), intent(out) :: aggr_out
241
242 ! Local variables
243 character(128) :: shape_name, nnodes_idx, nnodes_name
244 character(128) :: feat_dim_name, aggr_shape, zeros_name
245 character(128) :: target_us, axes1_name, msg_shape, target_exp
246
247 ! Get num_nodes from shape of vertex_in.
248
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(shape_name, '(A,"_vshape")') trim(tp)
249 call emit_node('Shape', trim(tp)//'_shape_v', &
250 trim(shape_name), '', nodes, num_nodes, &
251
6/12
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 13 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 5 times.
✗ Branch 16 not taken.
5 in1=trim(vertex_in))
252
253
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(nnodes_idx, '(A,"_nnodes_idx")') trim(tp)
254 call emit_constant_int64(trim(nnodes_idx), [0], [1], &
255
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 5 times.
✓ Branch 20 taken 5 times.
✗ Branch 21 not taken.
5 nodes, num_nodes, inits, num_inits)
256
257
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(nnodes_name, '(A,"_nnodes")') trim(tp)
258 call emit_node('Gather', trim(tp)//'_gather_nn', &
259 trim(nnodes_name), onnx_axis0_attr, nodes, num_nodes, &
260
7/14
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 14 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 19 not taken.
5 in1=trim(shape_name), in2=trim(nnodes_idx))
261
262 ! Concat [num_nodes, feature_dim] to create the scatter target shape.
263
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(feat_dim_name, '(A,"_feat_dim")') trim(tp)
264 call emit_constant_int64(trim(feat_dim_name), [feature_dim], [1], &
265
9/16
✓ Branch 1 taken 5 times.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✓ Branch 22 taken 5 times.
✗ Branch 23 not taken.
10 nodes, num_nodes, inits, num_inits)
266
267
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(aggr_shape, '(A,"_aggr_shape")') trim(tp)
268 call emit_node('Concat', trim(tp)//'_cat_shape', &
269 trim(aggr_shape), onnx_concat_axis0_attr, nodes, num_nodes, &
270
7/14
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 14 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 19 not taken.
5 in1=trim(nnodes_name), in2=trim(feat_dim_name))
271
272 ! ConstantOfShape creates the zero-filled aggregation buffer.
273
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(zeros_name, '(A,"_zeros")') trim(tp)
274 call emit_constant_of_shape_float(trim(tp)//'_zeros', &
275 trim(aggr_shape), 0.0_real32, trim(zeros_name), &
276
9/18
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✓ Branch 25 taken 5 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 5 times.
✗ Branch 28 not taken.
5 nodes, num_nodes, inits, num_inits)
277
278
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(target_us, '(A,"_tgt_us")') trim(tp)
279
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(axes1_name, '(A,"_us_ax1")') trim(tp)
280 call emit_constant_int64(trim(axes1_name), [1], [1], &
281
7/14
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 5 times.
✓ Branch 20 taken 5 times.
✗ Branch 21 not taken.
5 nodes, num_nodes, inits, num_inits)
282 call emit_node('Unsqueeze', trim(tp)//'_us_tgt', &
283 trim(target_us), '', nodes, num_nodes, &
284
7/14
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 14 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 19 not taken.
5 in1=trim(target_in), in2=trim(axes1_name))
285
286 ! Expand target indices to match the message rank for ScatterElements.
287
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(msg_shape, '(A,"_msg_shape")') trim(tp)
288 call emit_node('Shape', trim(tp)//'_shape_msg', &
289 trim(msg_shape), '', nodes, num_nodes, &
290
6/12
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 13 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 5 times.
✗ Branch 16 not taken.
5 in1=trim(message_in))
291
292
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(target_exp, '(A,"_tgt_exp")') trim(tp)
293 call emit_node('Expand', trim(tp)//'_expand_tgt', &
294 trim(target_exp), '', nodes, num_nodes, &
295
7/14
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 14 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 19 not taken.
5 in1=trim(target_us), in2=trim(msg_shape))
296
297 ! Scatter-add edge messages into the target-vertex slots.
298
1/2
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
5 write(aggr_out, '(A,"_aggr")') trim(tp)
299 call emit_node('ScatterElements', trim(tp)//'_scatter_add', &
300 trim(aggr_out), onnx_scatter_add_attr, nodes, num_nodes, &
301
8/16
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✓ Branch 15 taken 5 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 5 times.
✗ Branch 22 not taken.
5 in1=trim(zeros_name), in2=trim(target_exp), in3=trim(message_in))
302
303
1/2
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
5 end subroutine emit_scatter_aggregator
304 !###############################################################################
305
306
307 !###############################################################################
308 5 subroutine emit_weight_initialiser_2d( &
309
1/2
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
5 name, nrows, ncols, weight_data, inits, num_inits)
310 !! Store a 2D weight matrix as an ONNX initialiser in row-major order.
311 implicit none
312
313 ! Arguments
314 character(*), intent(in) :: name
315 integer, intent(in) :: nrows, ncols
316 real(real32), intent(in) :: weight_data(:)
317 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
318 integer, intent(inout) :: num_inits
319
320 5 num_inits = num_inits + 1
321
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 10 not taken.
5 inits(num_inits)%name = trim(name)
322
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
5 inits(num_inits)%data_type = 1
323
5/10
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
5 allocate(inits(num_inits)%dims(2))
324
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✓ Branch 6 taken 5 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 10 times.
✓ Branch 13 taken 5 times.
15 inits(num_inits)%dims = [ nrows, ncols ]
325
12/24
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✓ Branch 15 taken 5 times.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
5 allocate(inits(num_inits)%data(size(weight_data)))
326
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
5 call col_to_row_major_2d(weight_data, inits(num_inits)%data, nrows, ncols)
327
328
1/2
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
5 end subroutine emit_weight_initialiser_2d
329 !###############################################################################
330
331
332 !###############################################################################
333 2 subroutine emit_weight_initialiser_3d( &
334
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 name, nslices, nrows, ncols, weight_data, inits, num_inits)
335 !! Store a stacked bank of 2D weight matrices as one ONNX tensor.
336 implicit none
337
338 ! Arguments
339 character(*), intent(in) :: name
340 integer, intent(in) :: nslices, nrows, ncols
341 real(real32), intent(in) :: weight_data(:)
342 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
343 integer, intent(inout) :: num_inits
344
345 2 num_inits = num_inits + 1
346
4/8
✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
2 inits(num_inits)%name = trim(name)
347
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 inits(num_inits)%data_type = 1
348
5/10
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
2 allocate(inits(num_inits)%dims(3))
349
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 6 times.
✓ Branch 13 taken 2 times.
8 inits(num_inits)%dims = [ nslices, nrows, ncols ]
350
12/24
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✓ Branch 15 taken 2 times.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
2 allocate(inits(num_inits)%data(size(weight_data)))
351
352 ! Transpose each 2D slice from column-major to row-major before export.
353 block
354 integer :: d, slice_size
355 2 slice_size = nrows * ncols
356
2/2
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
10 do d = 1, nslices
357 call col_to_row_major_2d( &
358 weight_data((d-1)*slice_size+1 : d*slice_size), &
359
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
16 inits(num_inits)%data((d-1)*slice_size+1 : d*slice_size), &
360
11/22
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 8 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 8 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 8 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 8 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 8 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 8 times.
26 nrows, ncols)
361 end do
362 end block
363
364
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 end subroutine emit_weight_initialiser_3d
365 !###############################################################################
366
367 end module athena__onnx_msgpass_utils
368