GCC Code Coverage Report


Directory: src/athena/
File: athena_onnx_sub.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 1 1 100.0%
Functions: 0 0 -%
Branches: 15 24 62.5%

Line Branch Exec Source
1 submodule(athena__onnx) athena__onnx_submodule
2 !! Submodule containing implementations for ONNX operations
3 use athena__base_layer, only: base_layer_type, learnable_layer_type
4 use athena__misc_types, only: &
5 onnx_attribute_type, onnx_node_type, onnx_initialiser_type, onnx_tensor_type
6 use coreutils, only: real32, to_lower, to_upper, to_camel_case, icount
7 use athena__tools_infile, only: assign_val, assign_vec, allocate_and_assign_vec
8
9 contains
10
11 !###############################################################################
12 module subroutine write_onnx(file, network)
13 !! Export the network to ONNX format
14 implicit none
15
16 ! Arguments
17 class(network_type), intent(in) :: network
18 !! Instance of network
19 character(*), intent(in) :: file
20 !! File to export the network to
21
22 ! Local variables
23 integer :: unit, i, j, layer_id, input_layer_id
24 !! Unit number and loop indices
25 character(256) :: layer_name
26 !! Layer name for ONNX
27 character(64) :: node_name, input_name, tmp_input_name
28 !! Node and input name buffers
29 character(:), allocatable :: suffix
30 !! Suffix for input names
31
32 open(newunit=unit, file=file, status='replace')
33
34 ! Write ONNX header
35 write(unit, '(A)') 'ir_version: 8'
36 write(unit, '(A)') 'producer_name: "Athena"'
37 write(unit, '(A)') 'producer_version: "1.0"'
38 write(unit, '(A)') 'domain: "ai.onnx"'
39 write(unit, '(A)') 'model_version: 1'
40 write(unit, '(A)') 'doc_string: "Athena neural network model"'
41 write(unit, '(A)') ''
42
43 ! Write graph definition
44 write(unit, '(A)') 'graph {'
45 write(unit, '(A)') ' name: "athena_network"'
46 write(unit, '(A)') ''
47
48 ! Write nodes (layers)
49 write(unit, '(A)') ' # Nodes'
50 do i = 1, network%auto_graph%num_vertices
51 layer_id = network%auto_graph%vertex(network%vertex_order(i))%id
52 write(node_name, '("node_", I0)') network%model(layer_id)%layer%id
53
54 select case(trim(network%model(layer_id)%layer%type))
55 case('inpt')
56 layer_name = 'Input'
57 cycle
58 case('full')
59 layer_name = 'MatMul'
60 case('conv')
61 layer_name = 'Conv'
62 case('pool')
63 layer_name = to_camel_case( &
64 trim(adjustl(network%model(layer_id)%layer%subtype))//"_"//&
65 trim(adjustl(network%model(layer_id)%layer%type)), &
66 capitalise_first_letter = .true. &
67 )
68 case('actv')
69 layer_name = to_camel_case( &
70 adjustl(network%model(layer_id)%layer%subtype), &
71 capitalise_first_letter = .true. &
72 )
73 case('flat')
74 layer_name = 'Flatten'
75 case('batc')
76 layer_name = 'BatchNormalization'
77 case('drop')
78 layer_name = 'Dropout'
79 case('msgp')
80 layer_name = 'GNNLayer'
81 case default
82 layer_name = 'Unknown'
83 end select
84
85 write(unit, '(A)') ' node {'
86 write(unit, '(A,A,A)') ' name: "', trim(node_name), '"'
87 write(unit, '(A,A,A)') ' op_type: "', trim(layer_name), '"'
88
89 ! Write input connections
90 if(all(network%auto_graph%adjacency(:,network%vertex_order(i)).eq.0))then
91 cycle
92 ! write(unit, '(A,I0,A)') ' input: "input_',network%model(layer_id)%layer%id,'"'
93 else
94 do j = 1, network%auto_graph%num_vertices
95 input_layer_id = network%auto_graph%vertex(j)%id
96 if(network%auto_graph%adjacency(j,network%vertex_order(i)).eq.0) cycle
97 if(all(network%auto_graph%adjacency(:,j).eq.0))then
98 write(input_name,'("input_",I0)') &
99 network%model(input_layer_id)%layer%id
100 suffix = ''
101 else
102 write(input_name,'("node_",I0)') &
103 network%model(input_layer_id)%layer%id
104 suffix = '_output'
105 ! check if activation function is used, if so adjust suffix
106 select type(prev_layer => network%model(input_layer_id)%layer)
107 class is(learnable_layer_type)
108 if(prev_layer%activation%name.ne."none")then
109 suffix = '_' // trim(adjustl(prev_layer%activation%name)) &
110 // '_output'
111 end if
112 end select
113 end if
114 if(network%model(layer_id)%layer%use_graph_input)then
115 write(tmp_input_name,'(A,A,A)') &
116 trim(adjustl(input_name)), '_vertex', suffix
117 write(unit,'(4X,"input: """,A,"""")') trim(adjustl(tmp_input_name))
118 if(network%model(layer_id)%layer%input_shape(2) .gt. 0)then
119 write(tmp_input_name,'(A,A,A)') &
120 trim(adjustl(input_name)), '_edge', suffix
121 write(unit,'(4X,"input: """,A,"""")') trim(adjustl(tmp_input_name))
122 end if
123 else
124 write(unit,'(4X,"input: """,A,A,"""")') &
125 trim(adjustl(input_name)), suffix
126 end if
127 end do
128 end if
129 select type(layer => network%model(layer_id)%layer)
130 class is(learnable_layer_type)
131 do j = 1, size(layer%params)
132 write(unit, '(4X,"input: ""node_",I0,"_param",I0,"""")') &
133 network%model(layer_id)%layer%id, j
134 end do
135 end select
136 suffix = ''
137
138 ! Write output
139 if(network%model(layer_id)%layer%use_graph_output)then
140 write(unit, '(4X,"output: ""node_",I0,"_vertex_output",A,"""")') &
141 network%model(layer_id)%layer%id, trim(adjustl(suffix))
142 write(unit, '(4X,"output: ""node_",I0,"_edge_output",A,"""")') &
143 network%model(layer_id)%layer%id, trim(adjustl(suffix))
144 else
145 write(unit, '(4X,"output: ""node_",I0,"_output",A,"""")') &
146 network%model(layer_id)%layer%id, trim(adjustl(suffix))
147 end if
148
149 call write_onnx_attributes(unit, network%model(layer_id)%layer)
150
151 write(unit, '(A)') ' }'
152 write(unit, '(A)') ''
153
154 select type(layer => network%model(layer_id)%layer)
155 class is(learnable_layer_type)
156 call write_onnx_initialisers(unit, layer, prefix = trim(node_name) )
157 if(layer%activation%name.ne."none")then
158 if(layer%use_graph_output)then
159 call write_onnx_function( &
160 unit, layer%activation%name, &
161 prefix = trim(node_name)//'_vertex' &
162 )
163 if(network%model(layer_id)%layer%input_shape(2) .gt. 0)then
164 call write_onnx_function( &
165 unit, layer%activation%name, &
166 prefix = trim(node_name)//'_edge' &
167 )
168 end if
169 else
170 call write_onnx_function( &
171 unit, layer%activation%name, &
172 prefix = trim(node_name) &
173 )
174 end if
175 end if
176 end select
177 end do
178
179
180 ! write all layer output shapes
181 do i = 1, network%auto_graph%num_vertices
182 layer_id = network%auto_graph%vertex(network%vertex_order(i))%id
183 if(.not.allocated(network%model(layer_id)%layer%output_shape)) cycle
184 if(network%model(layer_id)%layer%use_graph_output)then
185 write(node_name, '("node_",I0,"_vertex_output")') &
186 network%model(layer_id)%layer%id
187 call write_onnx_tensor( &
188 unit, &
189 "value_info", &
190 trim(adjustl(node_name)), &
191 [ network%model(layer_id)%layer%output_shape(1) ], &
192 network%batch_size &
193 )
194 if(network%model(layer_id)%layer%output_shape(2) .gt. 0)then
195 write(node_name, '("node_",I0,"_edge_output")') &
196 network%model(layer_id)%layer%id
197 call write_onnx_tensor( &
198 unit, &
199 "value_info", &
200 trim(adjustl(node_name)), &
201 [ network%model(layer_id)%layer%output_shape(2) ], &
202 network%batch_size &
203 )
204 end if
205 else
206 write(node_name, '("node_",I0,"_output")') network%model(layer_id)%layer%id
207 call write_onnx_tensor( &
208 unit, &
209 "value_info", &
210 trim(adjustl(node_name)), &
211 network%model(layer_id)%layer%output_shape, &
212 network%batch_size &
213 )
214 end if
215 end do
216
217 ! Write inputs
218 write(unit, '(A)') ' # Inputs'
219 do i = 1, size(network%root_vertices, dim=1)
220 layer_id = network%auto_graph%vertex(network%root_vertices(i))%id
221 if(network%model(layer_id)%layer%use_graph_output)then
222 write(node_name, '("input_",I0,"_vertex")') network%model(layer_id)%layer%id
223 call write_onnx_tensor( &
224 unit, &
225 "input", &
226 trim(adjustl(node_name)), &
227 [ network%model(layer_id)%layer%input_shape(1) ], &
228 network%batch_size &
229 )
230 if(network%model(layer_id)%layer%input_shape(2) .gt. 0)then
231 write(node_name, '("input_",I0,"_edge")') network%model(layer_id)%layer%id
232 call write_onnx_tensor( &
233 unit, &
234 "input", &
235 trim(adjustl(node_name)), &
236 [ network%model(layer_id)%layer%input_shape(2) ], &
237 network%batch_size &
238 )
239 end if
240 else
241 write(node_name, '("input_",I0)') network%model(layer_id)%layer%id
242 call write_onnx_tensor( &
243 unit, &
244 "input", &
245 trim(adjustl(node_name)), &
246 network%model(layer_id)%layer%input_shape, &
247 network%batch_size &
248 )
249 end if
250 end do
251
252 ! Write outputs
253 write(unit, '(A)') ' # Outputs'
254 do i = 1, size(network%leaf_vertices, dim=1)
255 layer_id = network%auto_graph%vertex(network%leaf_vertices(i))%id
256 if(network%model(layer_id)%layer%use_graph_output)then
257 write(node_name, '("node_",I0,"_vertex_output")') &
258 network%model(layer_id)%layer%id
259 call write_onnx_tensor( &
260 unit, &
261 "output", &
262 trim(adjustl(node_name)), &
263 [ network%model(layer_id)%layer%output_shape(1) ], &
264 network%batch_size &
265 )
266 if(network%model(layer_id)%layer%output_shape(2) .gt. 0)then
267 write(node_name, '("node_",I0,"_edge_output")') &
268 network%model(layer_id)%layer%id
269 call write_onnx_tensor( &
270 unit, &
271 "output", &
272 trim(adjustl(node_name)), &
273 [ network%model(layer_id)%layer%output_shape(2) ], &
274 network%batch_size &
275 )
276 end if
277 else
278 select type(layer => network%model(layer_id)%layer)
279 class is(learnable_layer_type)
280 if(layer%activation%name.eq."none")then
281 suffix = ''
282 else
283 suffix = '_' // trim(adjustl(layer%activation%name))
284 end if
285 class default
286 suffix = ''
287 end select
288 write(node_name, '("node_",I0,A,"_output")') &
289 network%model(layer_id)%layer%id, trim(adjustl(suffix))
290 call write_onnx_tensor( &
291 unit, &
292 "output", &
293 trim(adjustl(node_name)), &
294 network%model(layer_id)%layer%output_shape, &
295 network%batch_size &
296 )
297 end if
298 end do
299
300 write(unit, '(A)') '}'
301
302 ! Write ONNX footer
303 write(unit, '(A)') 'opset_import {'
304 write(unit, '(A)') ' domain: "ai.onnx"'
305 write(unit, '(A,I0)') ' version: ', 13 ! ONNX opset version
306 write(unit, '(A)') '}'
307
308 close(unit)
309
310 end subroutine write_onnx
311 !###############################################################################
312
313
314 !###############################################################################
315 module function read_onnx(file, verbose) result(network)
316 !! Import a network from ONNX format
317 implicit none
318
319 ! Arguments
320 character(*), intent(in) :: file
321 !! File to import the network from
322 integer, optional, intent(in) :: verbose
323 !! Verbosity level (0=quiet, 1=normal, 2=debug)
324
325 ! Return value
326 type(network_type) :: network
327 !! Network instance
328
329 ! Local variables
330 integer :: unit, stat, i, j, k, itmp1
331 integer :: num_nodes, num_inputs, num_outputs, num_value_infos
332 character(1024) :: trimmed_line, line
333 character(256) :: op_type, node_name, temp_str
334 integer, allocatable, dimension(:) :: dims
335 real(real32), allocatable, dimension(:) :: float_data
336 logical :: in_node, in_initialiser, reading_dims, reading_data, in_input, in_output
337 integer :: node_id
338 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
339
340 integer :: verbose_
341 character(1024) :: buffer1
342 character(64) :: buffer2
343 character(64), allocatable :: inputs(:), outputs(:)
344
345 ! Node information storage
346 type(onnx_node_type), allocatable, dimension(:) :: nodes
347
348 ! Initialiser storage
349 type(onnx_initialiser_type), allocatable, dimension(:) :: initialisers
350 integer :: num_initialisers
351
352 ! Tensor info storage (inputs, outputs)
353 type(onnx_tensor_type), allocatable, dimension(:) :: input_tensors, &
354 output_tensors, value_infos
355
356 verbose_ = 0
357 if(present(verbose)) verbose_ = verbose
358
359 open(newunit=unit, file=file, status='old', action='read', iostat=stat)
360 if(stat .ne. 0)then
361 write(*,*) "ERROR: Could not open file: ", trim(file)
362 return
363 end if
364
365 ! Initialise counters
366 num_nodes = 0
367 num_initialisers = 0
368 num_inputs = 0
369 num_outputs = 0
370 num_value_infos = 0
371 in_node = .false.
372 in_initialiser = .false.
373 in_input = .false.
374 in_output = .false.
375 reading_dims = .false.
376 reading_data = .false.
377
378 ! First pass: count nodes, initialisers, and tensors
379 do
380 !call read_full_line(unit, line)
381 read(unit, '(A)', iostat=stat) line
382 if(stat .ne. 0) exit
383
384 trimmed_line = adjustl(trim(line))
385
386 if(index(trimmed_line, 'node {') .gt. 0)then
387 num_nodes = num_nodes + 1
388 elseif(index(trimmed_line, 'initializer {') .gt. 0)then
389 num_initialisers = num_initialisers + 1
390 elseif(index(trimmed_line, 'input {') .gt. 0)then
391 num_inputs = num_inputs + 1
392 elseif(index(trimmed_line, 'output {') .gt. 0)then
393 num_outputs = num_outputs + 1
394 elseif(index(trimmed_line, 'value_info {') .gt. 0)then
395 num_value_infos = num_value_infos + 1
396 end if
397 end do
398
399 ! Allocate storage
400 allocate(nodes(num_nodes))
401 allocate(initialisers(num_initialisers))
402 allocate(input_tensors(num_inputs))
403 allocate(output_tensors(num_outputs))
404 allocate(value_infos(num_value_infos))
405
406 ! Reset file for second pass
407 rewind(unit)
408
409 num_nodes = 0
410 num_initialisers = 0
411 num_inputs = 0
412 num_outputs = 0
413 num_value_infos = 0
414
415 ! Initialise node structures
416 do i = 1, size(nodes)
417 nodes(i)%num_inputs = 0
418 nodes(i)%num_outputs = 0
419 nodes(i)%op_type = ""
420 nodes(i)%name = ""
421 end do
422
423 ! Second pass: parse file content
424 do
425 read(unit, '(A)', iostat=stat) line
426 if(stat .ne. 0) exit
427
428 trimmed_line = trim(adjustl(line))
429 buffer1 = trimmed_line
430
431 ! Parse nodes
432 if(index(trimmed_line, 'node {') .gt. 0)then
433 in_node = .true.
434 num_nodes = num_nodes + 1
435 nodes(num_nodes)%num_inputs = 0
436 nodes(num_nodes)%num_outputs = 0
437 allocate(inputs(0))
438 allocate(outputs(0))
439 allocate(attributes(0))
440
441 elseif(in_node .and. index(trimmed_line, '}') .gt. 0)then
442 in_node = .false.
443 if(size(attributes) .gt. 0)then
444 allocate(nodes(num_nodes)%attributes(size(attributes)))
445 do i = 1, size(attributes)
446 nodes(num_nodes)%attributes(i) = attributes(i)
447 end do
448 end if
449 if(size(inputs) .gt. 0)then
450 allocate(nodes(num_nodes)%inputs(size(inputs)))
451 do i = 1, size(inputs)
452 nodes(num_nodes)%inputs(i) = inputs(i)
453 end do
454 end if
455 if(size(outputs) .gt. 0)then
456 allocate(nodes(num_nodes)%outputs(size(outputs)))
457 do i = 1, size(outputs)
458 nodes(num_nodes)%outputs(i) = outputs(i)
459 end do
460 end if
461 deallocate(attributes)
462 deallocate(inputs)
463 deallocate(outputs)
464
465 elseif(in_node)then
466 if(index(trimmed_line, 'name:') .gt. 0)then
467 call assign_val(buffer1, nodes(num_nodes)%name, itmp1, fs=":")
468 elseif(index(trimmed_line, 'op_type:') .gt. 0)then
469 call assign_val(buffer1, &
470 nodes(num_nodes)%op_type, itmp1, fs=":")
471 elseif(index(trimmed_line, 'input:') .gt. 0)then
472 nodes(num_nodes)%num_inputs = &
473 nodes(num_nodes)%num_inputs + 1
474 call assign_val(buffer1, buffer2, itmp1, fs=":")
475 !buffer2 = trim(adjustl(trimmed_line(index(trimmed_line, 'input:') + 6:)))
476 inputs = [ inputs, buffer2 ]
477 elseif(index(trimmed_line, 'output:') .gt. 0)then
478 nodes(num_nodes)%num_outputs = &
479 nodes(num_nodes)%num_outputs + 1
480 call assign_val(buffer1, buffer2, itmp1, fs=":")
481 !buffer2 = trim(adjustl(trimmed_line(index(trimmed_line, 'output:') + 7:)))
482 outputs = [ outputs, buffer2 ]
483 elseif(index(trimmed_line, 'attribute {') .gt. 0)then
484 attributes = [attributes, read_attribute(unit)]
485 end if
486 end if
487
488 ! Parse initialisers
489 if(index(trimmed_line, 'initializer {') .gt. 0)then
490 in_initialiser = .true.
491 num_initialisers = num_initialisers + 1
492 reading_dims = .false.
493 reading_data = .false.
494
495 elseif(in_initialiser .and. index(trimmed_line, '}') .gt. 0)then
496 in_initialiser = .false.
497 reading_dims = .false.
498 reading_data = .false.
499
500 elseif(in_initialiser)then
501 if(index(trimmed_line, 'name:') .gt. 0)then
502 call assign_val(buffer1, &
503 initialisers(num_initialisers)%name, itmp1, fs=":")
504 elseif(index(trimmed_line, 'dims:') .gt. 0)then
505 if(.not. reading_dims)then
506 reading_dims = .true.
507 if(allocated(dims)) deallocate(dims)
508 allocate(dims(0))
509 end if
510 call assign_val(buffer1, j, itmp1, fs=":")
511 dims = [dims, j]
512 initialisers(num_initialisers)%dims = dims
513 elseif(index(trimmed_line, 'float_data:') .gt. 0)then
514 reading_data = .true.
515 allocate(initialisers(num_initialisers)%data(0))
516 do while(reading_data)
517 read(unit, '(A)', iostat=stat) line
518 if(stat .ne. 0) exit
519 trimmed_line = trim(adjustl(line))
520 if(index(trimmed_line, 'float_data:') .gt. 0)then
521 trimmed_line = trimmed_line(index(trimmed_line, 'float_data:') + 11:)
522 end if
523 if(index(trimmed_line, ']') .gt. 0)then
524 reading_data = .false.
525 elseif(trim(adjustl(trimmed_line)) .ne. '')then
526 call allocate_and_assign_vec(trimmed_line, float_data, fs=":")
527 initialisers(num_initialisers)%data = &
528 [initialisers(num_initialisers)%data, float_data]
529 deallocate(float_data)
530 end if
531 end do
532
533 end if
534 end if
535
536 ! Parse input tensors
537 if(index(trimmed_line, 'input {') .gt. 0 .and. &
538 .not. in_node .and. .not. in_initialiser &
539 )then
540 in_input = .true.
541 num_inputs = num_inputs + 1
542 input_tensors(num_inputs) = read_input_output(unit)
543 in_input = .false.
544 end if
545
546 ! Parse output tensors
547 if(index(trimmed_line, 'output {') .gt. 0 .and. &
548 .not. in_node .and. .not. in_initialiser &
549 )then
550 in_output = .true.
551 num_outputs = num_outputs + 1
552 output_tensors(num_outputs) = read_input_output(unit)
553 in_output = .false.
554 end if
555
556 ! Parse value_info tensors
557 if(index(trimmed_line, 'value_info {') .gt. 0 .and. &
558 .not. in_node .and. .not. in_initialiser &
559 )then
560 num_value_infos = num_value_infos + 1
561 value_infos(num_value_infos) = read_input_output(unit)
562 end if
563 end do
564
565 close(unit)
566
567 ! Now construct the network from parsed information
568 call network%build_from_onnx( &
569 nodes, initialisers, input_tensors, value_infos, &
570 verbose=verbose_ &
571 )
572
573 end function read_onnx
574 !###############################################################################
575
576
577 !###############################################################################
578 function read_attribute(unit) result(attr)
579 !! Reads an entire attribute block from an ONNX file
580 !! Handles multi-line attributes (e.g., multiple ints or floats)
581 implicit none
582 integer, intent(in) :: unit
583 type(onnx_attribute_type) :: attr
584 character(1024) :: line, trimmed_line
585 character(1024) :: value_buffer
586 character(64) :: key, attr_type_key
587 character(256) :: value_str
588 integer :: stat, colon_pos
589 logical :: done
590
591 ! Initialise attribute
592 attr%name = ""
593 attr%type = ""
594 allocate(character(0) :: attr%val)
595 value_buffer = ""
596
597 ! The opening "attribute {" line has already been consumed by the caller
598 done = .false.
599
600 do while(.not. done)
601 read(unit, '(A)', iostat=stat) line
602 if(stat .ne. 0) exit
603
604 trimmed_line = adjustl(trim(line))
605 if(trim(trimmed_line) .eq. 'attribute {') cycle
606
607 ! Check for closing brace
608 if(index(trimmed_line, '}') .gt. 0) then
609 done = .true.
610 exit
611 end if
612
613 ! Parse line with colon separator
614 colon_pos = index(trimmed_line, ':')
615 if(colon_pos .gt. 0) then
616 key = adjustl(trim(trimmed_line(1:colon_pos-1)))
617 value_str = adjustl(trim(trimmed_line(colon_pos+1:)))
618 ! strip all quotes from value_str if present
619 if( &
620 ( &
621 value_str(1:1) .eq. '"' .and. &
622 value_str(len(trim(value_str)):len(trim(value_str))) .eq. '"' &
623 ) .or. ( &
624 value_str(1:1) .eq. '''' .and. &
625 value_str(len(trim(value_str)):len(trim(value_str))) .eq. '''' &
626 ) &
627 )then
628 value_str = value_str(2:len(trim(value_str))-1)
629 end if
630
631 select case(trim(key))
632 case('name')
633 attr%name = trim(value_str)
634 case('type')
635 if(attr%type .ne. '')then
636 write(0,*) "WARNING: Multiple 'type' entries in attribute. &
637 &Using the first one."
638 cycle
639 end if
640 attr_type_key = trim(value_str)
641 attr%type = to_lower(trim(attr_type_key))
642 case('ints', 'floats', 'strings', 'i', 'f', 's')
643 ! Accumulate multiple values with space separator
644 if(len_trim(value_buffer) .eq. 0) then
645 value_buffer = trim(value_str)
646 else
647 value_buffer = trim(value_buffer) // ' ' // trim(value_str)
648 end if
649 end select
650 end if
651 end do
652
653 ! Store accumulated values
654 if(len_trim(value_buffer) .gt. 0) then
655 attr%val = trim(value_buffer)
656 end if
657
658 end function read_attribute
659 !-------------------------------------------------------------------------------
660 function read_input_output(unit) result(tensor)
661 !! Reads an input or output block from an ONNX file
662 implicit none
663 integer, intent(in) :: unit
664 type(onnx_tensor_type) :: tensor
665
666 integer :: i
667 character(1024) :: line, trimmed_line
668 character(256) :: name
669 integer :: stat
670 integer :: num_open_braces, num_close_braces
671
672 ! Initialise tensor
673 tensor%elem_type = 0
674 allocate(tensor%dims(0))
675 num_open_braces = 0
676 num_close_braces = 0
677
678 do
679 read(unit, '(A)', iostat=stat) line
680 ! remove comments
681 if(stat .eq. 0 .and. index(line, '#') .gt. 0) then
682 line = line(1:index(line, '#')-1)
683 end if
684 if(stat .ne. 0) exit
685
686 trimmed_line = adjustl(trim(line))
687 if(index(trimmed_line, 'name:') .gt. 0)then
688 call assign_val(trimmed_line, name, stat, fs=":")
689 elseif(index(trimmed_line, 'tensor_type {') .gt. 0)then
690 tensor = read_tensor_type(unit)
691 end if
692
693 ! count number of open { and close } to determine when shape block ends
694 do i = 1, len_trim(trimmed_line)
695 if (trimmed_line(i:i) .eq. '{') num_open_braces = num_open_braces + 1
696 if (trimmed_line(i:i) .eq. '}') num_close_braces = num_close_braces + 1
697 end do
698
699 ! Check for closing brace
700 if(num_close_braces .ge. num_open_braces .and. num_open_braces.gt.0)then
701 exit
702 end if
703
704 end do
705 tensor%name = trim(name)
706
707 end function read_input_output
708 !-------------------------------------------------------------------------------
709 function read_tensor_type(unit) result(tensor)
710 !! Reads the tensor type block from an ONNX file to extract dimensions
711 implicit none
712 integer, intent(in) :: unit
713 type(onnx_tensor_type) :: tensor
714
715 integer :: i
716 character(1024) :: line, trimmed_line, buffer
717 integer :: stat, dim_value
718 logical :: done, in_shape
719 integer :: num_open_braces, num_close_braces, shape_brace_idx
720
721 ! Initialise tensor
722 tensor%elem_type = 0
723 allocate(tensor%dims(0))
724
725 done = .false.
726 in_shape = .false.
727 num_open_braces = 0
728 num_close_braces = 0
729
730 do while(.not. done)
731 read(unit, '(A)', iostat=stat) line
732 if(stat .ne. 0) exit
733
734 trimmed_line = adjustl(trim(line))
735 ! remove comments
736 if(index(trimmed_line, '#') .gt. 0) then
737 trimmed_line = trim(adjustl(trimmed_line(1:index(trimmed_line, '#')-1)))
738 end if
739
740 if(index(trimmed_line, 'elem_type:') .gt. 0)then
741 call assign_val(trimmed_line, tensor%elem_type, stat, fs=":")
742 elseif(index(trimmed_line, 'shape {') .gt. 0)then
743 in_shape = .true.
744 shape_brace_idx = num_open_braces
745 buffer = trimmed_line(:index(trimmed_line, 'shape {') + 6)
746 do i = 1, len_trim(buffer)
747 if (buffer(i:i) .eq. '{') shape_brace_idx = shape_brace_idx + 1
748 end do
749 elseif(in_shape .and. index(trimmed_line, 'dim_value:') .gt. 0)then
750 call assign_val(trimmed_line, dim_value, stat, fs=":")
751 tensor%dims = [tensor%dims, dim_value]
752 end if
753
754 ! count number of open { and close } to determine when shape block ends
755 do i = 1, len_trim(trimmed_line)
756 if (trimmed_line(i:i) .eq. '{') num_open_braces = num_open_braces + 1
757 if (trimmed_line(i:i) .eq. '}') num_close_braces = num_close_braces + 1
758 end do
759
760 ! Check if we are still in shape block
761 if(in_shape .and. num_open_braces - num_close_braces .lt. shape_brace_idx) then
762 in_shape = .false.
763 end if
764
765 ! Check for closing brace
766 if(num_close_braces .ge. num_open_braces .and. num_open_braces.gt.0) then
767 done = .true.
768 end if
769 end do
770
771 end function read_tensor_type
772 !###############################################################################
773
774
775 !###############################################################################
776 subroutine write_onnx_tensor(unit, tensor_type, name, output_shape, batch_size)
777 !! Write ONNX value info for a layer
778 implicit none
779
780 ! Arguments
781 integer, intent(in) :: unit
782 !! File unit
783 character(*), intent(in) :: tensor_type
784 !! Type of the tensor
785 character(*), intent(in) :: name
786 !! Name of the layer
787 integer, intent(in), dimension(:) :: output_shape
788 !! Shape of the layer output
789 integer, intent(in) :: batch_size
790 !! Batch size for the output
791
792 ! Local variables
793 integer :: i
794 !! Loop index
795
796
797 write(unit, '(A,A,A)') ' ',tensor_type,' {'
798 write(unit, '(A,A,A)') ' name: "',name,'"'
799 write(unit, '(A)') ' type {'
800 write(unit, '(A)') ' tensor_type {'
801 write(unit, '(A)') ' elem_type: 1'
802 write(unit, '(A)') ' shape {'
803 write(unit, '(A,I0)') ' dim { dim_value: ', max(1,batch_size)
804 write(unit, '(A)') ' }'
805 do i = size(output_shape), 1, -1
806 write(unit, '(A,I0)') ' dim { dim_value: ', output_shape(i)
807 write(unit, '(A)') ' }'
808 end do
809 write(unit, '(A)') ' }'
810 write(unit, '(A)') ' }'
811 write(unit, '(A)') ' }'
812 write(unit, '(A)') ' }'
813
814 end subroutine write_onnx_tensor
815 !###############################################################################
816
817
818 !###############################################################################
819 subroutine write_onnx_initialisers(unit, layer, prefix)
820 !! Write ONNX initialisers (weights and biases)
821 implicit none
822
823 ! Arguments
824 integer, intent(in) :: unit
825 !! File unit
826 class(learnable_layer_type), intent(in) :: layer
827 !! Instance of a layer
828 character(*), intent(in) :: prefix
829 !! Prefix for weight and bias names
830
831 ! Local variables
832 integer :: i
833 !! Loop index
834 integer :: num_params
835 !! Number of parameters
836 character(64) :: name
837 !! Names for parameters
838
839
840 if(allocated(layer%params))then
841 do i = 1, size(layer%params)
842 num_params = size(layer%params(i)%val,1)
843 write(name, '(A,A,I0)') trim(prefix), '_param', i
844 write(unit, '(2X,A)') 'initializer {'
845 write(unit, '(4X,"name: """,A,"""")') trim(name)
846 write(unit, '(4X,A)') 'data_type: 1' ! FLOAT
847 write(unit, '(4X,A,I0)') 'dims: ', num_params
848
849 write(unit, '(4X,"float_data: [ ")')
850 write(unit, '(20(F0.6,", "))') layer%params(i)%val(1:num_params-1,:)
851 write(unit, '(F0.6)') layer%params(i)%val(num_params,:)
852 write(unit, '(A)') ' ]'
853 write(unit, '(A)') ' }'
854 write(unit, '(A)') ''
855 end do
856 end if
857
858 end subroutine write_onnx_initialisers
859 !###############################################################################
860
861
862 !###############################################################################
863 subroutine write_onnx_function(unit, function_name, prefix)
864 !! Write ONNX function definition
865 implicit none
866
867 ! Arguments
868 integer, intent(in) :: unit
869 !! File unit
870 character(*), intent(in) :: function_name
871 !! Name of the function
872 character(*), intent(in) :: prefix
873 !! Prefix for the function name (may be blank)
874
875 ! Local variables
876 character(256) :: full_name
877 !! Full name of the function
878 character(:), allocatable :: function_name_camel_case
879 !! Camel case version of the function name
880
881 function_name_camel_case = &
882 to_camel_case(trim(adjustl(function_name)), capitalise_first_letter = .true.)
883 if(prefix .eq. "")then
884 full_name = trim(adjustl(function_name))
885 else
886 full_name = trim(prefix) // "_" // trim(adjustl(function_name))
887 end if
888
889
890 write(unit, '(A)') ' node {'
891 write(unit, '(A,A,A)') ' name: "', trim(full_name), '"'
892 write(unit, '(A,A,A)') ' op_type: "', trim(function_name_camel_case), '"'
893 write(unit, '(A,A,A)') ' input: "', trim(prefix), '_output"'
894 write(unit, '(A,A,A)') ' output: "', trim(full_name), '_output"'
895 write(unit, '(A)') ' }'
896 write(unit, '(A)') ''
897
898 end subroutine write_onnx_function
899 !###############################################################################
900
901
902 !###############################################################################
903 subroutine write_onnx_attributes(unit, layer)
904 !! Write ONNX attributes for a layer
905 implicit none
906
907 ! Arguments
908 integer, intent(in) :: unit
909 !! File unit
910 class(base_layer_type), intent(in) :: layer
911 !! Instance of a layer
912
913 ! Local variables
914 integer :: i, j, itmp1
915 !! Loop indices and attribute value count
916 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
917 character(:), allocatable :: type_lw, type_up
918 integer, allocatable, dimension(:) :: ivar_list
919 real(real32), allocatable, dimension(:) :: rvar_list
920
921
922 attributes = layer%get_attributes()
923 if(allocated(attributes).and. size(attributes) .gt. 0)then
924 do i = 1, size(attributes)
925 write(unit, '(4X,A)') 'attribute {'
926 write(unit, '(6X,"name: """,A,"""")') trim(attributes(i)%name)
927 ! determine whether the attribute is a list or a single value
928 type_lw = to_lower(trim(adjustl(attributes(i)%type)))
929 type_up = to_upper(trim(adjustl(attributes(i)%type)))
930 itmp1 = icount(attributes(i)%val)
931 select case(type_lw)
932 case('ints','int')
933 allocate(ivar_list(itmp1))
934 read(attributes(i)%val,*) ivar_list
935 do j = 1, size(ivar_list)
936 write(unit, '(6X,A,": ",I0)') type_lw, ivar_list(j)
937 end do
938 deallocate(ivar_list)
939 case('floats','float')
940 allocate(rvar_list(itmp1))
941 read(attributes(i)%val,*) rvar_list
942 do j = 1, size(rvar_list), 1
943 write(unit, '(6X,A,": ",F0.6)') type_lw, rvar_list(j)
944 end do
945 deallocate(rvar_list)
946 case('strings','string')
947 case default
948 write(unit, '(6X,A,": ",A)') trim(adjustl(attributes(i)%type)), &
949 trim(adjustl(attributes(i)%val))
950 end select
951 write(unit,'(6X,"type: ",A)') type_up
952 write(unit,'(4X,"}")')
953 end do
954 end if
955
956 end subroutine write_onnx_attributes
957 !###############################################################################
958
959
960 15/24 24 end submodule athena__onnx_submodule
✗ Branch 0 (213→214) not taken.
✓ Branch 1 (213→215) taken 7 times.
✗ Branch 2 (215→216) not taken.
✓ Branch 3 (215→217) taken 7 times.
✓ Branch 4 (218→219) taken 3 times.
✓ Branch 5 (218→236) taken 4 times.
✗ Branch 6 (295→296) not taken.
✓ Branch 7 (295→297) taken 8 times.
✗ Branch 8 (297→298) not taken.
✓ Branch 9 (297→299) taken 8 times.
✓ Branch 10 (300→301) taken 4 times.
✓ Branch 11 (300→314) taken 4 times.
✗ Branch 12 (376→377) not taken.
✓ Branch 13 (376→378) taken 8 times.
✗ Branch 14 (378→379) not taken.
✓ Branch 15 (378→380) taken 8 times.
✓ Branch 16 (381→382) taken 4 times.
✓ Branch 17 (381→414) taken 4 times.
✗ Branch 18 (681→682) not taken.
✓ Branch 19 (681→683) taken 1 times.
✗ Branch 20 (683→684) not taken.
✓ Branch 21 (683→685) taken 1 times.
✓ Branch 22 (686→687) taken 1 times.
✗ Branch 23 (686→711) not taken.
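
For orientation, below is a minimal, hypothetical driver sketching how the write_onnx/read_onnx round trip covered above might be exercised. The public visibility of network_type, write_onnx, and read_onnx from athena__onnx, and the output file name, are assumptions for illustration, not details taken from this report.

! Hypothetical usage sketch, assuming athena__onnx publicly exports these names.
program onnx_roundtrip_sketch
  use athena__onnx, only: network_type, write_onnx, read_onnx
  implicit none
  type(network_type) :: net, reloaded

  ! `net` is assumed to have been built and trained elsewhere.

  ! Export the network to the ONNX-style text format written by write_onnx.
  call write_onnx('model.onnx.txt', net)

  ! Re-import it; verbose=1 requests normal-level parser output.
  reloaded = read_onnx('model.onnx.txt', verbose=1)
end program onnx_roundtrip_sketch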