GCC Code Coverage Report


Directory: src/athena/
File: athena_network_sub.f90
Date: 2025-12-10 07:37:07
             Exec   Total   Coverage
Lines:         10      54      18.5%
Functions:      1       2      50.0%
Branches:      12     116      10.3%

Line Branch Exec Source
1 submodule(athena__network) athena__network_submodule
2 !! Submodule containing implementations for the network module
3 #ifdef _OPENMP
4 use omp_lib
5 #endif
6 use coreutils, only: stop_program, print_warning, to_lower
7 use athena__misc_ml, only: shuffle
8
9 use athena__accuracy, only: categorical_score, mae_score, mse_score, r2_score
10 use athena__base_layer, only: learnable_layer_type, merge_layer_type
11 #if defined(GFORTRAN)
12 use athena__container_layer, only: container_reduction
13 #endif
14
15 use athena__container_layer, only: &
16 list_of_layer_types, allocate_list_of_layer_types, &
17 list_of_onnx_layer_creators, allocate_list_of_onnx_layer_creators
18
19 ! Layer types
20 use athena__flatten_layer, only: flatten_layer_type
21 use athena__add_layer, only: add_layer_type
22 use athena__concat_layer, only: concat_layer_type
23 use athena__input_layer, only: input_layer_type
24 use athena__msgpass_layer, only: msgpass_layer_type
25 use athena__recurrent_layer, only: recurrent_layer_type
26
27 ! #ifdef _OPENMP
28 ! !$omp declare reduction( &
29 ! !$omp& network_reduction : network_type:omp_out%network_reduction(omp_in)) &
30 ! !$omp& initialiser(omp_priv = omp_orig)
31 ! #endif
32
33 contains
34
35 !###############################################################################
36 module subroutine network_reduction(this, source)
37 !! Procedure to add two networks together
38 implicit none
39
40 ! Arguments
41 class(network_type), intent(inout) :: this
42 !! Instance of network
43 type(network_type), intent(in) :: source
44 !! Instance of network to be added to this
45
46 ! Local variables
47 integer :: i
48 !! Loop index
49
50 this%metrics(1)%val = this%metrics(1)%val + source%metrics(1)%val
51 this%metrics(2)%val = this%metrics(2)%val + source%metrics(2)%val
52 do i=1,size(this%model)
53 select type(layer_this => this%model(i)%layer)
54 class is(learnable_layer_type)
55 select type(layer_source => source%model(i)%layer)
56 class is(learnable_layer_type)
57 call layer_this%reduce(layer_source)
58 end select
59 end select
60 end do
61
62 end subroutine network_reduction
63 !###############################################################################
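  ! A minimal, hedged usage sketch (not the library's confirmed pattern): the
  ! commented-out "!$omp declare reduction" block near the top of this
  ! submodule suggests network_reduction is meant to merge per-thread copies
  ! of a network; a manual equivalent inside a parallel region might look like:
  !
  !   !$omp critical
  !   call net_total%network_reduction(net_thread)  ! net_thread = hypothetical private copy
  !   !$omp end critical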
64
65
66 !###############################################################################
67 module subroutine network_copy(this, source)
68 !! Procedure to copy a network
69 implicit none
70
71 ! Arguments
72 class(network_type), intent(inout) :: this
73 !! Instance of network
74 type(network_type), intent(in), target :: source
75 !! Instance of network to be copied
76
77 ! Local variables
78 integer :: i
79 !! Loop index
80
81
82 this%metrics = source%metrics
83 this%model = source%model
84 this%num_layers = source%num_layers
85 this%batch_size = source%batch_size
86 this%num_params = source%num_params
87 this%num_outputs = source%num_outputs
88 this%optimiser = source%optimiser
89 this%vertex_order = source%vertex_order
90 this%root_vertices = source%root_vertices
91 this%leaf_vertices = source%leaf_vertices
92 this%loss = source%loss
93 this%get_accuracy => source%get_accuracy
94 this%auto_graph = source%auto_graph
95
96 end subroutine network_copy
97 !###############################################################################
98
99
100 !##############################################################################!
101 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
102 !##############################################################################!
103
104
105 !###############################################################################
106 module subroutine build_vertex_order(this)
107 !! Generate the order of the layers in the network
108 !!
109 !! The order is generated by a depth-first search (DFS) of the network
110 !! graph, so that every layer appears after all of the layers that feed
111 !! into it (a short worked example follows the dfs procedure below).
112 implicit none
113
114 ! Arguments
115 class(network_type), intent(inout) :: this
116 !! Instance of network
117
118 ! Local variables
119 integer :: i, order_index
120 !! Loop index
121 logical, dimension(this%auto_graph%num_vertices) :: visited
122 !! Array to store whether a vertex has been visited
123
124 visited = .false.
125 if(allocated(this%vertex_order)) deallocate(this%vertex_order)
126 allocate(this%vertex_order(this%auto_graph%num_vertices), source=0)
127
128 order_index = 0
129 do i = this%auto_graph%num_vertices, 1, -1
130 if(.not.visited(i)) call this%dfs( &
131 i, visited, this%vertex_order, order_index &
132 )
133 end do
134
135 end subroutine build_vertex_order
136 !###############################################################################
137
138
139 !###############################################################################
140 recursive module subroutine dfs( &
141 this, vertex_index, visited, order, order_index &
142 )
143 !! Depth first search algorithm
144 implicit none
145
146 ! Arguments
147 class(network_type), intent(in) :: this
148 !! Instance of network
149 integer, intent(in) :: vertex_index
150 !! Index of the vertex to start the search from
151 logical, dimension(this%auto_graph%num_vertices), intent(inout) :: visited
152 !! Array to store whether a vertex has been visited
153 integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order
154 !! Array to store the order of the vertices
155 integer, intent(inout) :: order_index
156 !! Index of the current vertex in the order array
157
158 ! Local variables
159 integer :: i
160 !! Loop index
161
162 visited(vertex_index) = .true.
163 do i = 1, this%auto_graph%num_vertices, 1
164 if(this%auto_graph%adjacency(i,vertex_index).ne.0)then
165 if(.not.visited(i)) call this%dfs(i, visited, order, order_index)
166 end if
167 end do
168 order_index = order_index + 1
169 order(order_index) = vertex_index
170
171 end subroutine dfs
172 !###############################################################################
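  ! Worked example (illustrative, based on the two procedures above): for a
  ! diamond-shaped network with edges 1->2, 1->3, 2->4 and 3->4,
  ! build_vertex_order starts from the highest-numbered unvisited vertex (4),
  ! dfs first recurses into each parent (the non-zero entries of
  ! adjacency(:,4)), and a vertex is appended only after all of its parents,
  ! giving vertex_order = [1, 2, 3, 4]. Parents therefore always precede
  ! their children, which is what compile() relies on when propagating
  ! input shapes through the model.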
173
174
175 !###############################################################################
176 module subroutine build_root_vertices(this)
177 !! Calculate the root vertices of the network
178 implicit none
179
180 ! Arguments
181 class(network_type), intent(inout) :: this
182 !! Instance of network
183
184 ! Local variables
185 integer :: i
186 !! Loop index
187
188 if(allocated(this%root_vertices)) deallocate(this%root_vertices)
189 allocate(this%root_vertices(0))
190 ! from = 1
191 do i = 1, this%auto_graph%num_vertices
192 if(all(this%auto_graph%adjacency(:,i).eq.0))then
193 this%root_vertices = [this%root_vertices, i]
194 ! to = from + this%model(i)layer%num_input_data - 1
195 ! this%root_bounds = [ this%root_bounds, reshape([from,to], [2,1]) ]
196 ! from = to + 1
197 end if
198 end do
199 end subroutine build_root_vertices
200 !###############################################################################
201
202
203 !###############################################################################
204 module subroutine build_leaf_vertices(this)
205 !! Calculate the output vertices of the network
206 implicit none
207
208 ! Arguments
209 class(network_type), intent(inout) :: this
210 !! Instance of network
211
212 ! Local variables
213 integer :: i
214 !! Loop index
215
216 if(allocated(this%leaf_vertices)) deallocate(this%leaf_vertices)
217 allocate(this%leaf_vertices(0))
218 do i = 1, this%auto_graph%num_vertices
219 if(all(this%auto_graph%adjacency(i,:).eq.0))then
220 this%leaf_vertices = [this%leaf_vertices, i]
221 end if
222 end do
223 end subroutine build_leaf_vertices
224 !###############################################################################
225
226
227
228
229
230 !##############################################################################!
231 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
232 !##############################################################################!
233
234
235 !###############################################################################
236 module subroutine print(this, file)
237 !! Print the network to a file
238 use coreutils, only: to_upper
239 use athena__io_utils, only: athena__version__
240 implicit none
241
242 ! Arguments
243 class(network_type), intent(in) :: this
244 !! Instance of network
245 character(*), intent(in) :: file
246 !! File to print the network to
247
248 ! Local variables
249 integer :: l, v, e, vertex_index, unit
250 !! Loop index
251 integer :: operator_in, operator_out
252 !! Operators for the layer
253 character(3) :: operator_str
254 !! String to store the operator
255 character(256) :: suffix, fmt
256 !! Suffix for the layer
257 integer, dimension(:), allocatable :: input_list, output_list
258
259 open(newunit=unit,file=file,status='replace')
260
261 write(unit,'("NETWORK_SETTINGS")')
262 write(unit,'(3X,"ATHENA_VERSION = ",A)') trim(adjustl(athena__version__))
263 if(allocated(this%name)) write(unit,'(3X,"NAME = ",A)') trim(adjustl(this%name))
264 write(unit,'(3X,"EPOCH = ",I0)') this%epoch
265 write(unit,'(3X,"BATCH_SIZE = ",I0)') this%batch_size
266 write(unit,'(3X,"ACCURACY = ",F0.9)') this%accuracy_val
267 write(unit,'(3X,"LOSS = ",F0.9)') this%loss_val
268 if(allocated(this%accuracy_method))then
269 write(unit,'(3X,"ACCURACY_METHOD = ",A)') trim(adjustl(this%accuracy_method))
270 end if
271 if(allocated(this%loss_method))then
272 write(unit,'(3X,"LOSS_METHOD = ",A)') trim(adjustl(this%loss_method))
273 end if
274 if(allocated(this%optimiser))then
275 write(unit,'(3X,"OPTIMISER: ",A)') trim(adjustl(this%optimiser%name))
276 call this%optimiser%print_to_unit(unit=unit)
277 write(unit,'(3X,"END OPTIMISER")')
278 end if
279 write(unit,'("END NETWORK_SETTINGS")')
280
281 do v = 1, size(this%vertex_order,dim=1), 1
282 l = this%vertex_order(v)
283 operator_in = -1
284 operator_out = -1
285 allocate(input_list(0), output_list(0))
286 do e = 1, this%auto_graph%num_edges
287 if(-this%auto_graph%edge(e)%index(2).eq.l)then
288 if(operator_in.gt.0.and.this%auto_graph%edge(e)%id.ne.operator_in)then
289 write(0,*) "WARNING: multiple operators for layer ", l
290 write(0,*) " using operator ", this%auto_graph%edge(e)%id
291 end if
292 operator_in = this%auto_graph%edge(e)%id
293 vertex_index = &
294 findloc( this%vertex_order, this%auto_graph%edge(e)%index(1), 1 )
295 input_list = [ input_list, vertex_index ]
296 end if
297 if(this%auto_graph%edge(e)%index(1).eq.l)then
298 if(operator_out.gt.0.and.this%auto_graph%edge(e)%id.ne.operator_out)then
299 write(0,*) "WARNING: multiple operators for layer ", l
300 write(0,*) " using operator ", this%auto_graph%edge(e)%id
301 end if
302 operator_out = this%auto_graph%edge(e)%id
303 vertex_index = &
304 findloc( this%vertex_order, this%auto_graph%edge(e)%index(2), 1 )
305 output_list = [ output_list, vertex_index ]
306 end if
307 end do
308
309 suffix = ""; operator_str = ""
310 select case(operator_in)
311 case(1)
312 operator_str = " ||"
313 case(2)
314 operator_str = " +"
315 case(3)
316 operator_str = " *"
317 end select
318 ! get size of input_list and make the formatted string
319 if(size(input_list).eq.0)then
320 write(suffix,'(A," []")') trim(operator_str)
321 else
322 write(fmt,'("(A,A,"" ["",",I0,"(1X,I0),"" ]"")")') size(input_list)
323 write(suffix,fmt) trim(suffix), operator_str, input_list
324 end if
325 ! select case(operator_out)
326 ! case(1)
327 ! operator_str = " ||"
328 ! case(2)
329 ! operator_str = " +"
330 ! case(3)
331 ! operator_str = " *"
332 ! end select
333 ! if(size(output_list).gt.0)then
334 ! write(fmt,'("(A,A,"" ["",",I0,"(1X,I0),"" ]"")")') size(output_list)
335 ! write(suffix,fmt) trim(suffix), operator_str, output_list
336 ! end if
337
338 write(unit,'(A,A)') to_upper(trim(this%model(l)%layer%name)), trim(suffix)
339 call this%model(l)%layer%print(unit=unit, print_header_footer=.false.)
340
341 write(unit,'("END ",A)') to_upper(trim(this%model(l)%layer%name))
342 deallocate(input_list, output_list)
343 end do
344 close(unit)
345
346 end subroutine print
347 !###############################################################################
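  ! For reference, a hedged sketch of the kind of file print() writes (values
  ! and the layer name below are hypothetical; the fields follow the write
  ! statements above and are parsed back by read() further down):
  !
  !   NETWORK_SETTINGS
  !      ATHENA_VERSION = 1.0.0
  !      EPOCH = 10
  !      BATCH_SIZE = 32
  !      ACCURACY = 0.900000000
  !      LOSS = 0.123456789
  !      OPTIMISER: adam
  !         ...
  !      END OPTIMISER
  !   END NETWORK_SETTINGS
  !   FULL || [ 1 ]
  !      ...
  !   END FULL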
348
349
350 !###############################################################################
351 module subroutine read(this, file)
352 !! Read the network from a file
353 use coreutils, only: icount
354 implicit none
355
356 ! Arguments
357 class(network_type), intent(inout) :: this
358 !! Instance of network
359 character(*), intent(in) :: file
360 !! File to read the network from
361
362 ! Local variables
363 integer :: i, unit, stat, itmp1
364 !! Loop index, file unit, IO status, and temporary integer
365 integer, dimension(:), allocatable :: input_list, output_list
366 !! List of input and output layers
367 character(256) :: buffer, err_msg, input_str, output_str
368 !! Buffer for reading lines from file
369 character(20) :: name
370 !! Name of the layer
371 character(2) :: operator_in, operator_out
372 !! Operator for the layer
373 integer :: layer_index
374 !! Index of the layer in the list of layer types
375
376
377 if(.not.allocated(list_of_layer_types))then
378 call allocate_list_of_layer_types()
379 end if
380
381 open(newunit=unit,file=file,action='read')
382 i = 0
383 card_loop: do
384 i = i + 1
385 read(unit,'(A)',iostat=stat) buffer
386 if(stat.lt.0)then
387 exit card_loop
388 elseif(stat.gt.0)then
389 call stop_program("error encountered in network read")
390 return
391 end if
392 if(trim(adjustl(buffer)).eq."") cycle card_loop
393
394 !! check if a tag line
395 if(scan(buffer,'=').ne.0)then
396 write(0,*) "WARNING: unexpected line in read file"
397 write(0,*) trim(buffer)
398 write(0,*) " skipping..."
399 cycle card_loop
400 end if
401
402 !! check for card
403 name = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
404 if(name.eq."NETWORK_SETTINGS")then
405 call this%read_network_settings(unit)
406 cycle card_loop
407 end if
408 buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
409 operator_in = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
410 buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
411 input_str = trim(adjustl(buffer(1:scan(buffer,']'))))
412 if(scan(input_str,'[').ne.0)then
413 input_str = &
414 trim(adjustl(input_str(scan(input_str,'[')+1:scan(input_str,']')-1)))
415 itmp1 = icount(input_str)
416 allocate(input_list(itmp1))
417 read(input_str,*) input_list
418 else
419 allocate(input_list, source = [-1])
420 end if
421 buffer = trim(adjustl(buffer(scan(buffer,']')+1:)))
422 operator_out = trim(adjustl(buffer(1:scan(buffer,' ')-1)))
423 buffer = trim(adjustl(buffer(scan(buffer,' ')+1:)))
424 output_str = trim(adjustl(buffer(1:scan(buffer,']'))))
425 if(scan(output_str,'[').ne.0)then
426 output_str = &
427 trim(adjustl(output_str(scan(output_str,'[')+1:scan(output_str,']')-1)))
428 itmp1 = icount(output_str)
429 allocate(output_list(itmp1))
430 read(output_str,*) output_list
431 else
432 allocate(output_list(0))
433 end if
434 name = trim(adjustl(to_lower(name)))
435 layer_index = &
436 findloc( &
437 [ list_of_layer_types(:)%name ], &
438 name, &
439 dim = 1 &
440 )
441 if(layer_index.eq.0)then
442 write(err_msg,'("unrecognised card ''",A,"''")') trim(adjustl(name))
443 call stop_program(err_msg)
444 return
445 end if
446 call this%add( &
447 list_of_layer_types(layer_index)%read_ptr(unit), &
448 input_list = input_list, &
449 operator = operator_in &
450 )
451 if(allocated(input_list)) deallocate(input_list)
452 if(allocated(output_list)) deallocate(output_list)
453 end do card_loop
454 close(unit)
455
456 end subroutine read
457 !###############################################################################
458
459
460 !###############################################################################
461 module subroutine read_network_settings(this, unit)
462 !! Read the network settings from a file
463 use athena__tools_infile, only: assign_val, assign_vec
464 use coreutils, only: to_lower, to_upper, icount
465 implicit none
466
467 ! Arguments
468 class(network_type), intent(inout) :: this
469 !! Instance of network
470 integer, intent(in) :: unit
471 !! File unit
472
473 ! Local variables
474 integer :: stat
475 !! File status
476 integer :: itmp1
477 !! Temporary integer
478 character(20) :: accuracy_method, loss_method
479 !! Methods for accuracy and loss
480 character(256) :: buffer, tag, err_msg, name_
481 !! Buffer for reading lines, tag for identifying lines, error message
482
483
484 ! Loop over tags in layer card
485 !---------------------------------------------------------------------------
486 accuracy_method = ""
487 loss_method = ""
488 tag_loop: do
489
490 ! Check for end of file
491 !------------------------------------------------------------------------
492 read(unit,'(A)',iostat=stat) buffer
493 if(stat.ne.0)then
494 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
495 to_upper(this%name)
496 call stop_program(err_msg)
497 return
498 end if
499 if(trim(adjustl(buffer)).eq."") cycle tag_loop
500
501 ! Check for end of layer card
502 !------------------------------------------------------------------------
503 if(trim(adjustl(buffer)).eq."END NETWORK_SETTINGS")then
504 backspace(unit)
505 exit tag_loop
506 end if
507
508 tag=trim(adjustl(buffer))
509 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
510 if(scan(buffer,":").ne.0) tag=trim(tag(:scan(tag,":")-1))
511
512 ! Read parameters from save file
513 !------------------------------------------------------------------------
514 select case(trim(tag))
515 case("ATHENA_VERSION")
516 ! Ignore this tag, it is only for information
517 case("NAME")
518 call assign_val(buffer, name_, itmp1)
519 if(len(trim(adjustl(name_))) .gt. 0)then
520 this%name = trim(adjustl(name_))
521 end if
522 case("EPOCH")
523 call assign_val(buffer, this%epoch, itmp1)
524 case("BATCH_SIZE")
525 call assign_val(buffer, this%batch_size, itmp1)
526 case("ACCURACY")
527 call assign_val(buffer, this%accuracy_val, itmp1)
528 case("LOSS")
529 call assign_val(buffer, this%loss_val, itmp1)
530 case("ACCURACY_METHOD")
531 call assign_val(buffer, accuracy_method, itmp1)
532 call this%set_accuracy(accuracy_method)
533 case("LOSS_METHOD")
534 call assign_val(buffer, loss_method, itmp1)
535 call this%set_loss(loss_method)
536 case("OPTIMISER")
537 backspace(unit)
538 call this%read_optimiser_settings(unit)
539 case default
540 ! Don't look for "e" due to scientific notation of numbers
541 ! ... i.e. exponent (E+00)
542 if(scan(to_lower(trim(adjustl(buffer))),&
543 'abcdfghijklmnopqrstuvwxyz').eq.0)then
544 cycle tag_loop
545 elseif(tag(:3).eq.'END')then
546 cycle tag_loop
547 end if
548 write(err_msg,'("Unrecognised line in input file: ",A)') &
549 trim(adjustl(buffer))
550 call stop_program(err_msg)
551 return
552 end select
553 end do tag_loop
554
555
556 ! Check for end of layer card
557 !---------------------------------------------------------------------------
558 read(unit,'(A)') buffer
559 if(trim(adjustl(buffer)).ne."END NETWORK_SETTINGS")then
560 write(0,*) trim(adjustl(buffer))
561 write(err_msg,'("END NETWORK_SETTINGS not where expected")')
562 call stop_program(err_msg)
563 return
564 end if
565
566 end subroutine read_network_settings
567 !###############################################################################
568 module subroutine read_optimiser_settings(this, unit)
569 !! Read the optimiser settings from a file
570 use coreutils, only: to_lower, to_upper, icount
571 use athena__optimiser, only: &
572 sgd_optimiser_type, adam_optimiser_type, rmsprop_optimiser_type, &
573 adagrad_optimiser_type, base_optimiser_type
574 implicit none
575
576 ! Arguments
577 class(network_type), intent(inout) :: this
578 !! Instance of network
579 integer, intent(in) :: unit
580 !! File unit
581
582 ! Local variables
583 integer :: stat
584 !! File status
585 character(20) :: optimiser_name
586 !! Name of the optimiser
587 character(256) :: buffer, err_msg, tmp
588 !! Buffer for reading lines, error message
589
590 ! Read until end of optimiser settings
591 read(unit,'(A)',iostat=stat) buffer
592 if(stat.ne.0)then
593 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
594 to_upper(this%name)
595 call stop_program(err_msg)
596 return
597 end if
598 read(buffer,*) tmp, optimiser_name
599
600 select case(trim(adjustl(to_lower(optimiser_name))))
601 case("sgd")
602 this%optimiser = sgd_optimiser_type()
603 case("adam")
604 this%optimiser = adam_optimiser_type()
605 case("rmsprop")
606 this%optimiser = rmsprop_optimiser_type()
607 case("adagrad")
608 this%optimiser = adagrad_optimiser_type()
609 case("","base")
610 this%optimiser = base_optimiser_type()
611 case default
612 write(err_msg,'("Unrecognised optimiser: ",A)') trim(adjustl(optimiser_name))
613 call stop_program(err_msg)
614 return
615 end select
616 call this%optimiser%read(unit)
617
618 end subroutine read_optimiser_settings
619 !###############################################################################
620
621
622 !###############################################################################
623 module subroutine build_from_onnx( &
624 this, nodes, initialisers, inputs, value_info, verbose &
625 )
626 !! Build network from ONNX nodes and initialisers
627 use coreutils, only: to_lower
628 implicit none
629
630 ! Arguments
631 class(network_type), intent(inout) :: this
632 !! Instance of network
633 type(onnx_node_type), dimension(:), intent(in) :: nodes
634 !! Array of ONNX nodes
635 type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
636 !! Array of ONNX initialisers
637 type(onnx_tensor_type), dimension(:), intent(in) :: inputs
638 !! Array of ONNX inputs
639 type(onnx_tensor_type), dimension(:), intent(in) :: value_info
640 !! Array of ONNX value infos
641 integer, optional, intent(in) :: verbose
642 !! Verbosity level
643
644 ! Local variables
645 integer :: i, j, k, j_out, layer_index
646 !! Loop indices
647 integer :: verbose_
648 !! Verbosity level
649 character(20) :: op_type
650 !! Lowercase op_type
651 character(64) :: tmp_name
652 !! Temporary name for matching
653 character(256) :: err_msg
654 !! Error message
655 integer, dimension(:), allocatable :: input_shape
656 !! Shape of input layer
657 integer, dimension(:), allocatable :: input_list
658 !! List of input layers
659 type(onnx_initialiser_type), dimension(:), allocatable :: init_list
660 !! List of initialisers for a specific node
661 type(onnx_tensor_type), dimension(:), allocatable :: value_info_list
662 !! List of value info tensors
663
664 verbose_ = 0
665 if(present(verbose)) verbose_ = verbose
666
667
668 if(.not.allocated(list_of_onnx_layer_creators))then
669 call allocate_list_of_onnx_layer_creators()
670 end if
671
672 do i = 1, size(inputs)
673 write(*,*) "Processing ONNX input: ", trim(inputs(i)%name)
674 input_shape = inputs(i)%dims(size(inputs(i)%dims):2:-1)
675
676 call this%add( &
677 input_layer_type(input_shape, index=i) &
678 )
679 end do
680
681 ! Loop through nodes and create layers
682 do i = 1, size(nodes)
683 if(verbose_.gt.0) write(*,*) "Processing ONNX node: ", trim(nodes(i)%name), &
684 " (", trim(nodes(i)%op_type), ")"
685 op_type = trim(adjustl(nodes(i)%op_type))
686
687 layer_index = &
688 findloc( &
689 [ list_of_onnx_layer_creators(:)%op_type ], &
690 op_type, &
691 dim = 1 &
692 )
693 if(layer_index.eq.0)then
694 write(err_msg,'("unrecognised op_type ''",A,"''")') trim(adjustl(nodes(i)%op_type))
695 call stop_program(err_msg)
696 return
697 end if
698
699 ! find all input layers and initialisers for this node
700 ! ... i.e. check over inputs for name matches
701 j_out = 0
702 allocate(init_list(0))
703 allocate(input_list(0))
704 allocate(value_info_list(0))
705 do j = 1, size(nodes(i)%inputs)
706 do k = 1, size(initialisers)
707 if(trim(nodes(i)%inputs(j)) .eq. trim(initialisers(k)%name))then
708 init_list = [ init_list, initialisers(k) ]
709 end if
710 end do
711 do k = 1, size(inputs)
712 if(trim(nodes(i)%inputs(j)) .eq. trim(inputs(k)%name))then
713 input_list = [ input_list, k ]
714 end if
715 end do
716 tmp_name = trim(nodes(i)%inputs(j))
717 if(index(tmp_name, "_output").gt.0) &
718 tmp_name = trim(tmp_name(:index(tmp_name, "_output")-1))
719 do k = 1, size(nodes)
720 if(trim(tmp_name) .eq. trim(nodes(k)%name))then
721 input_list = [ input_list, k + size(inputs) ]
722 end if
723 end do
724 end do
725 do j = 1, size(nodes(i)%outputs)
726 do k = 1, size(value_info)
727 if(trim(nodes(i)%outputs(j)) .eq. trim(value_info(k)%name))then
728 value_info_list = [ value_info_list, value_info(k) ]
729 end if
730 end do
731 end do
732 if(size(init_list)+size(input_list).ne.size(nodes(i)%inputs))then
733 if(verbose_.gt.0)then
734 write(0,*) "WARNING: not all inputs found for node ", &
735 trim(nodes(i)%name)
736 end if
737 end if
738
739 ! assume default operator
740
741 call this%add( &
742 list_of_onnx_layer_creators(layer_index)%create_ptr( &
743 nodes(i), init_list, value_info_list &
744 ), &
745 input_list = input_list &
746 ! operator = operator_in &
747 )
748 deallocate(input_list)
749 deallocate(init_list)
750 deallocate(value_info_list)
751 end do
752
753 if(verbose_.gt.0) write(*,*) "ONNX model built with ", this%num_layers, " layers."
754
755 end subroutine build_from_onnx
756 !###############################################################################
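  ! Note on the name matching above (illustrative): an ONNX node input such as
  ! "conv1_output" (a hypothetical tensor name) is matched to the node named
  ! "conv1" by stripping the trailing "_output" suffix, and the resulting layer
  ! index is offset by size(inputs) because the graph input layers are added to
  ! the model before the node-derived layers.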
757
758
759 !##############################################################################!
760 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
761 !##############################################################################!
762
763
764 !###############################################################################
765 module subroutine add(this, layer, input_list, output_list, operator)
766 !! Add a layer to the network
767 implicit none
768
769 ! Arguments
770 class(network_type), intent(inout) :: this
771 !! Instance of network
772 class(base_layer_type), intent(in) :: layer
773 !! Layer to add to the network
774 integer, dimension(:), optional, intent(in) :: input_list
775 !! List of input layers
776 integer, dimension(:), optional, intent(in) :: output_list
777 !! List of output layers
778 class(*), optional, intent(in) :: operator
779 !! Operator to use to connect the layers
780
781 ! Local variables
782 integer :: i, vertex_index
783 !! Loop index
784 integer :: operator_
785 !! Operator to use to connect the layers
786 character(256) :: err_msg
787 !! Error message
788 integer, dimension(2) :: vertex_indices
789 !! Indices of the vertices to connect
790 type(container_layer_type), allocatable, dimension(:) :: model
791 !! Model to add the layer to
792
793
794
795 if(.not.allocated(this%model))then
796 this%model = [container_layer_type()]
797 this%num_layers = 1
798 else
799 allocate(model(size(this%model,dim=1)+1))
800 do i = 1, size(this%model,dim=1)
801 allocate(model(i)%layer, source=this%model(i)%layer)
802 end do
803 call move_alloc(model, this%model)
804 this%num_layers = this%num_layers + 1
805 end if
806 allocate(this%model(size(this%model,dim=1))%layer, source=layer)
807 this%model(size(this%model,dim=1))%layer%id = this%num_layers
808
809
810 operator_ = 1
811 if(present(operator))then
812 select type(operator)
813 type is(integer)
814 operator_ = operator
815 type is(character(*))
816 select case(trim(to_lower(operator)))
817 case("||", "concat", "concatenate", "append")
818 operator_ = 1
819 case("+", "add")
820 operator_ = 2
821 case("*", "x", "mul", "multiply")
822 operator_ = 3
823 end select
824 end select
825 end if
826 if(operator_.gt.2.or.operator_.lt.1)then
827 call stop_program("invalid operator")
828 return
829 end if
830
831 ! edge_index(1) = index of the previous layer
832 ! abs(edge_index(2)) = index of the current layer
833 ! the -ve sign of edge_index(2) indicates that the edge goes from the
834 ! previous layer to the current layer
835 ! i.e. forward pass flows from positive to negative
836 ! adjacency(i,:) is all of the layers that i feeds forward to
837 ! adjacency(:,i) is all of the layers that feed forward to i
838 ! (i.e. the backward pass)
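        ! Illustrative example of this convention (not from the original source):
        ! for a simple chain of layers 1 -> 2 -> 3 the graph holds
        !   edge(1)%index = [1, -2],  edge(2)%index = [2, -3]
        ! so adjacency(1,2) and adjacency(2,3) are non-zero (each storing the
        ! corresponding edge index), while adjacency(:,1) being all zero marks
        ! vertex 1 as a root and adjacency(3,:) being all zero marks vertex 3
        ! as a leaf.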
839 this%auto_graph%directed = .true.
840 call this%auto_graph%add_vertex( &
841 feature=[1._real32], id=this%num_layers, update_adjacency=.true. &
842 )
843 if(present(input_list))then
844 do i = 1, size(input_list), 1
845 if(input_list(i).eq.0)then
846 vertex_index = 0
847 elseif( &
848 input_list(i).le.-this%auto_graph%num_vertices .or. &
849 input_list(i).gt.this%auto_graph%num_vertices &
850 )then
851 write(err_msg, &
852 '("input vertex index ",I0," out of range (",I0,":",I0,")")' &
853 ) &
854 input_list(i), &
855 -this%auto_graph%num_vertices +1, &
856 this%auto_graph%num_vertices
857 call stop_program(err_msg)
858 return
859 elseif(input_list(i).lt.0)then
860 vertex_index = this%auto_graph%num_vertices + input_list(i)
861 else
862 vertex_index = findloc( &
863 [this%auto_graph%vertex(:)%id], &
864 input_list(i), 1 &
865 )
866 end if
867 vertex_indices = [ vertex_index, -this%auto_graph%num_vertices ]
868 call this%auto_graph%add_edge( &
869 index = vertex_indices, &
870 feature = [ 1._real32 ], &
871 id = operator_, &
872 update_adjacency = .true. &
873 )
874 end do
875 elseif(trim(layer%type).ne."inpt".and.this%auto_graph%num_vertices.gt.1)then
876 vertex_indices = [ &
877 this%auto_graph%num_vertices - 1, &
878 -this%auto_graph%num_vertices &
879 ]
880 call this%auto_graph%add_edge( &
881 index = vertex_indices, &
882 feature = [ 1._real32 ], &
883 id = operator_, &
884 update_adjacency = .true. &
885 )
886 end if
887
888 if(present(output_list))then
889 do i = 1, size(output_list), 1
890 vertex_index = findloc( &
891 [this%auto_graph%vertex(:)%id], &
892 output_list(i), 1 &
893 )
894 vertex_indices = [ this%auto_graph%num_vertices, -vertex_index ]
895 call this%auto_graph%add_edge( &
896 index = vertex_indices, &
897 feature = [ 1._real32 ], &
898 id = operator_, &
899 update_adjacency = .true. &
900 )
901 end do
902 end if
903
904 end subroutine add
905 !###############################################################################
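  ! A minimal, hedged usage sketch of add() (argument values are illustrative;
  ! input_layer_type is only used in the forms already seen in this submodule,
  ! and my_hidden_layer is a hypothetical class(base_layer_type) object):
  !
  !   call net%add( input_layer_type(input_shape=[32], index=1) )
  !   call net%add( input_layer_type(input_shape=[32], index=2) )
  !   call net%add( my_hidden_layer, input_list=[1, 2], operator="||" )
  !
  ! With no input_list, the new layer is connected to the previously added
  ! layer (see the elseif branch above).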
906
907
908 !###############################################################################
909 module function network_setup( &
910 layers, optimiser, loss_method, accuracy_method, &
911 metrics, batch_size &
912 ) result(network)
913 !! Setup the network
914 implicit none
915
916 ! Arguments
917 type(container_layer_type), dimension(:), intent(in) :: layers
918 !! Layers to add to the network
919 class(base_optimiser_type), optional, intent(in) :: optimiser
920 !! Optimiser to use for training
921 class(*), optional, intent(in) :: loss_method
922 !! Loss method
923 character(*), optional, intent(in) :: accuracy_method
924 !! Accuracy method
925 class(*), dimension(..), optional, intent(in) :: metrics
926 !! Metrics
927 integer, optional, intent(in) :: batch_size
928 !! Batch size
929
930 type(network_type) :: network
931 !! Network to setup
932
933 ! Local variables
934 integer :: l
935 !! Loop index
936
937
938 !---------------------------------------------------------------------------
939 ! Handle optional arguments
940 !---------------------------------------------------------------------------
941 if(present(loss_method)) call network%set_loss(loss_method)
942 if(present(accuracy_method)) call network%set_accuracy(accuracy_method)
943 if(present(metrics)) call network%set_metrics(metrics)
944 if(present(batch_size)) network%batch_size = batch_size
945 network%auto_graph%directed = .true.
946
947
948 !---------------------------------------------------------------------------
949 ! Add layers to network
950 !---------------------------------------------------------------------------
951 do l = 1, size(layers)
952 call network%add(layers(l)%layer)
953 end do
954
955
956 !---------------------------------------------------------------------------
957 ! Compile network if optimiser present
958 !---------------------------------------------------------------------------
959 if(present(optimiser)) call network%compile(optimiser)
960
961 end function network_setup
962 !###############################################################################
963
964
965 !###############################################################################
966 module subroutine set_metrics(this, metrics)
967 !! Set the metrics for the network
968 use coreutils, only: to_lower
969 implicit none
970
971 ! Arguments
972 class(network_type), intent(inout) :: this
973 !! Instance of network
974 class(*), dimension(..), intent(in) :: metrics
975 !! Metrics
976
977 ! Local variables
978 integer :: i
979 !! Loop index
980
981
982 this%metrics%active = .false.
983 this%metrics(1)%key = "loss"
984 this%metrics(2)%key = "accuracy"
985 this%metrics%threshold = 1.E-1_real32
986 select rank(metrics)
987 #if defined(GFORTRAN)
988 rank(0)
989 select type(metrics)
990 type is(character(*))
991 ! ERROR: ifort cannot recognise that the rank of metrics has ...
992 ! ... already been resolved to scalar here, hence the GFORTRAN guard
993 where(to_lower(trim(metrics)).eq.this%metrics%key)
994 this%metrics%active = .true.
995 end where
996 end select
997 #endif
998 rank(1)
999 select type(metrics)
1000 type is(character(*))
1001 do i=1,size(metrics,1)
1002 where(to_lower(trim(metrics(i))).eq.this%metrics%key)
1003 this%metrics%active = .true.
1004 end where
1005 end do
1006 type is(metric_dict_type)
1007 if(size(metrics,1).eq.2)then
1008 this%metrics(:2) = metrics(:2)
1009 else
1010 call stop_program("invalid length array for metric_dict_type")
1011 return
1012 end if
1013 end select
1014 rank default
1015 call stop_program("provided metrics rank in compile invalid")
1016 return
1017 end select
1018
1019 end subroutine set_metrics
1020 !###############################################################################
1021
1022
1023 !###############################################################################
1024 module subroutine set_loss(this, loss_method, verbose)
1025 !! Set the loss method for the network
1026 use coreutils, only: to_lower
1027 use athena__loss, only: &
1028 bce_loss_type, &
1029 cce_loss_type, &
1030 mae_loss_type, &
1031 mse_loss_type, &
1032 nll_loss_type, &
1033 huber_loss_type
1034 implicit none
1035
1036 ! Arguments
1037 class(network_type), intent(inout) :: this
1038 !! Instance of network
1039 class(*), intent(in) :: loss_method
1040 !! Loss method
1041 integer, optional, intent(in) :: verbose
1042 !! Verbosity level
1043
1044 ! Local variables
1045 integer :: verbose_
1046 !! Verbosity level
1047 character(len=:), allocatable :: loss_method_
1048 !! Loss method
1049 character(256) :: err_msg
1050 !! Error message
1051
1052
1053 if(present(verbose))then
1054 verbose_ = verbose
1055 else
1056 verbose_ = 0
1057 end if
1058
1059 !---------------------------------------------------------------------------
1060 ! Handle analogous definitions
1061 !---------------------------------------------------------------------------
1062
1063 !---------------------------------------------------------------------------
1064 ! Set loss method
1065 !---------------------------------------------------------------------------
1066 select type(loss_method)
1067 class is(base_loss_type)
1068 this%loss = loss_method
1069 if(verbose_.gt.0) write(*,*) "Loss method: ", trim(loss_method%name)
1070 loss_method_ = trim(loss_method%name)
1071 type is(character(*))
1072 loss_method_ = to_lower(loss_method)
1073 select case(loss_method_)
1074 case("binary_crossentropy")
1075 loss_method_ = "bce"
1076 case("categorical_crossentropy")
1077 loss_method_ = "cce"
1078 case("mean_absolute_error")
1079 loss_method_ = "mae"
1080 case("mean_squared_error")
1081 loss_method_ = "mse"
1082 case("negative_log_likelihood")
1083 loss_method_ = "nll"
1084 case("huber")
1085 loss_method_ = "hub"
1086 end select
1087 select case(loss_method_)
1088 case("bce")
1089 this%loss = bce_loss_type()
1090 if(verbose_.gt.0) write(*,*) "Loss method: Binary Cross Entropy"
1091 case("cce")
1092 this%loss = cce_loss_type()
1093 if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy"
1094 case("mae")
1095 this%loss = mae_loss_type()
1096 if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error"
1097 case("mse")
1098 this%loss = mse_loss_type()
1099 if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error"
1100 case("nll")
1101 this%loss = nll_loss_type()
1102 if(verbose_.gt.0) write(*,*) "Loss method: Negative Log Likelihood"
1103 case("hub")
1104 this%loss = huber_loss_type()
1105 if(verbose_.gt.0) write(*,*) "Loss method: Huber"
1106 case default
1107 write(err_msg,'(A)') &
1108 "No loss method provided" // &
1109 achar(13) // achar(10) // &
1110 "Failed loss method: "//trim(loss_method_)
1111 call stop_program(trim(err_msg))
1112 return
1113 end select
1114 end select
1115 this%loss_method = loss_method_
1116
1117 end subroutine set_loss
1118 !###############################################################################
1119
1120
1121 !###############################################################################
1122 module subroutine set_accuracy(this, accuracy_method, verbose)
1123 !! Set the accuracy method for the network
1124 use coreutils, only: to_lower
1125 use athena__accuracy, only: &
1126 categorical_score, &
1127 mae_score, &
1128 mse_score, &
1129 rmse_score, &
1130 r2_score
1131 implicit none
1132
1133 ! Arguments
1134 class(network_type), intent(inout) :: this
1135 !! Instance of network
1136 character(*), intent(in) :: accuracy_method
1137 !! Accuracy method
1138 integer, optional, intent(in) :: verbose
1139 !! Verbosity level
1140
1141 ! Local variables
1142 integer :: verbose_
1143 !! Verbosity level
1144 character(len=:), allocatable :: accuracy_method_
1145 !! Accuracy method
1146 character(256) :: err_msg
1147 !! Error message
1148
1149
1150 if(present(verbose))then
1151 verbose_ = verbose
1152 else
1153 verbose_ = 0
1154 end if
1155
1156 !---------------------------------------------------------------------------
1157 ! Handle analogous definitions
1158 !---------------------------------------------------------------------------
1159 accuracy_method_ = to_lower(accuracy_method)
1160 select case(accuracy_method_)
1161 case("categorical")
1162 accuracy_method_ = "cat"
1163 case("mean_absolute_error")
1164 accuracy_method_ = "mae"
1165 case("mean_squared_error")
1166 accuracy_method_ = "mse"
1167 case("root_mean_squared_error")
1168 accuracy_method_ = "rmse"
1169 case("r2", "r^2", "r squared")
1170 accuracy_method_ = "r2"
1171 end select
1172
1173 !---------------------------------------------------------------------------
1174 ! Set accuracy method
1175 !---------------------------------------------------------------------------
1176 select case(accuracy_method_)
1177 case("cat")
1178 this%get_accuracy => categorical_score
1179 if(verbose_.gt.0) write(*,*) "Accuracy method: Categorical "
1180 case("mae")
1181 this%get_accuracy => mae_score
1182 if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Absolute Error"
1183 case("mse")
1184 this%get_accuracy => mse_score
1185 if(verbose_.gt.0) write(*,*) "Accuracy method: Mean Squared Error"
1186 case("rmse")
1187 this%get_accuracy => rmse_score
1188 if(verbose_.gt.0) write(*,*) "Accuracy method: Root Mean Squared Error"
1189 case("r2")
1190 this%get_accuracy => r2_score
1191 if(verbose_.gt.0) write(*,*) "Accuracy method: R^2"
1192 case default
1193 write(err_msg,'(A)') &
1194 "No accuracy method provided" // &
1195 achar(13) // achar(10) // &
1196 "Failed accuracy method: "//trim(accuracy_method_)
1197 call stop_program(trim(err_msg))
1198 return
1199 end select
1200 this%accuracy_method = accuracy_method_
1201
1202 end subroutine set_accuracy
1203 !###############################################################################
1204
1205
1206 !###############################################################################
1207 module subroutine reset(this)
1208 !! Reset the network
1209 implicit none
1210
1211 ! Arguments
1212 class(network_type), intent(inout) :: this
1213 !! Instance of network
1214
1215 this%epoch = 0
1216 this%accuracy_val = 0._real32
1217 this%loss_val = huge(1._real32)
1218 this%batch_size = 0
1219 this%num_layers = 0
1220 this%num_outputs = 0
1221 if(allocated(this%optimiser)) deallocate(this%optimiser)
1222 call this%set_metrics(["loss"])
1223 if(allocated(this%model)) deallocate(this%model)
1224 if(allocated(this%loss)) deallocate(this%loss)
1225 this%get_accuracy => null()
1226
1227 if(allocated(this%vertex_order)) deallocate(this%vertex_order)
1228 if(allocated(this%leaf_vertices)) deallocate(this%leaf_vertices)
1229 if(allocated(this%root_vertices)) deallocate(this%root_vertices)
1230 this%auto_graph = graph_type(directed=.true.)
1231
1232 end subroutine reset
1233 !###############################################################################
1234
1235
1236 !###############################################################################
1237 module subroutine compile( &
1238 this, optimiser, loss_method, accuracy_method, &
1239 metrics, batch_size, verbose &
1240 )
1241 !! Compile the network
1242 implicit none
1243
1244 ! Arguments
1245 class(network_type), intent(inout) :: this
1246 !! Instance of network
1247 class(base_optimiser_type), optional, intent(in) :: optimiser
1248 !! Optimiser to use for training
1249 class(*), optional, intent(in) :: loss_method
1250 !! Loss method
1251 character(*), optional, intent(in) :: accuracy_method
1252 !! Accuracy method
1253 class(*), dimension(..), optional, intent(in) :: metrics
1254 !! Metrics
1255 integer, optional, intent(in) :: batch_size
1256 !! Batch size
1257 integer, optional, intent(in) :: verbose
1258 !! Verbosity level
1259
1260 ! Local variables
1261 integer :: i, j, k, child_id, parent_id, layer_id, num_inputs, input_rank
1262 !! Loop indices and layer/rank bookkeeping
1263 integer :: parent_vertex, vertex_idx
1264 !! Vertex indices
1265 integer :: layer_rank, parent_rank, operator
1266 !! Ranks of layers
1267 integer :: verbose_
1268 !! Verbosity level
1269 logical :: use_graph_input
1270 !! Boolean whether to use graph input
1271 logical :: l_flatten_child, l_set_input_shape
1272 !! Booleans whether to flatten child or set input shape
1273 integer, dimension(:), allocatable :: input_shape, &
1274 child_vertices, parent_vertices, output_ranks, parent_ids
1275 !! Shapes of the input and output of the layers
1276 integer, dimension(:,:), allocatable :: merge_shape
1277 !! Shapes of the inputs to merge layers
1278 class(base_layer_type), allocatable :: &
1279 t_input_layer, t_flatten_layer, t_merge_layer
1280 !! Temporary input, flatten, and merge layers
1281
1282
1283 !---------------------------------------------------------------------------
1284 ! Initialise optional arguments
1285 !---------------------------------------------------------------------------
1286 verbose_ = 0; if(present(verbose)) verbose_ = verbose
1287
1288
1289 !---------------------------------------------------------------------------
1290 ! Initialise metrics
1291 !---------------------------------------------------------------------------
1292 if(present(metrics)) call this%set_metrics(metrics)
1293
1294
1295 !---------------------------------------------------------------------------
1296 ! Initialise loss and accuracy methods
1297 !---------------------------------------------------------------------------
1298 if(present(loss_method)) call this%set_loss(loss_method, verbose_)
1299 if(present(accuracy_method)) &
1300 call this%set_accuracy(accuracy_method, verbose_)
1301
1302
1303 !---------------------------------------------------------------------------
1304 ! Check for input layers at root vertices
1305 !---------------------------------------------------------------------------
1306 this%auto_graph%directed = .true.
1307 call this%build_root_vertices()
1308 do i = 1, size(this%root_vertices)
1309 layer_id = this%auto_graph%vertex(this%root_vertices(i))%id
1310 if(.not.allocated(this%model(layer_id)%layer%input_shape))then
1311 call stop_program("input_shape of first layer not defined")
1312 return
1313 end if
1314 use_graph_input = .false.
1315 select type( root => this%model(layer_id)%layer)
1316 class is(input_layer_type)
1317 cycle
1318 class is(learnable_layer_type)
1319 input_shape = root%input_shape
1320 use_graph_input = root%use_graph_input
1321 class default
1322 input_shape = root%input_shape
1323 end select
1324 t_input_layer = input_layer_type(&
1325 input_shape = input_shape, &
1326 index = i, &
1327 use_graph_input = use_graph_input, &
1328 verbose=verbose_ &
1329 )
1330 call this%add( &
1331 t_input_layer, output_list = [ this%model(layer_id)%layer%id ] &
1332 )
1333 ! NEED TO CALL layer%init?
1334 deallocate(input_shape)
1335 deallocate(t_input_layer)
1336 this%root_vertices(i) = this%num_layers
1337 if(i.eq.1)then
1338 do j = 1, this%auto_graph%num_edges
1339 if(this%auto_graph%edge(j)%index(1).eq.0) &
1340 this%auto_graph%edge(j)%index(1) = this%num_layers
1341 end do
1342 end if
1343 end do
1344 call this%auto_graph%generate_adjacency()
1345
1346
1347 !---------------------------------------------------------------------------
1348 ! Identify whether input is graph type
1349 !---------------------------------------------------------------------------
1350 if( &
1351 this%model( &
1352 this%auto_graph%vertex(this%root_vertices(1))%id &
1353 )%layer%use_graph_input &
1354 )then
1355 this%use_graph_input = .true.
1356 else
1357 this%use_graph_input = .false.
1358 end if
1359
1360
1361 !---------------------------------------------------------------------------
1362 ! Check for zero input rank layers
1363 !---------------------------------------------------------------------------
1364 do i = 1, size(this%auto_graph%vertex, dim = 1)
1365 layer_id = this%auto_graph%vertex(i)%id
1366 if(this%model(layer_id)%layer%input_rank.eq.0)then
1367 parent_ids = pack( &
1368 [ ( &
1369 this%auto_graph%vertex(j)%id, &
1370 j = 1, size(this%auto_graph%adjacency(:,i)) &
1371 ) ], &
1372 this%auto_graph%adjacency(:,i) .ne. 0 &
1373 )
1374 if(size(parent_ids).eq.0) cycle
1375 output_ranks = [ ( this%model(parent_ids(j))%layer%output_rank, &
1376 j=1,size(parent_ids) ) ]
1377 if(any(output_ranks.ne.output_ranks(1)))then
1378 write(0,*) output_ranks
1379 call stop_program( &
1380 "input rank of layer "//trim(this%model(layer_id)%layer%name) // &
1381 " is zero, but multiple parents with different output ranks" &
1382 )
1383 return
1384 end if
1385 input_rank = this%model(parent_ids(1))%layer%output_rank
1386 call this%model(layer_id)%layer%set_rank( &
1387 input_rank = input_rank, &
1388 output_rank = input_rank &
1389 )
1390 end if
1391 end do
1392
1393
1394 !---------------------------------------------------------------------------
1395 ! Check for required flatten layers
1396 !---------------------------------------------------------------------------
1397 i = 0
1398 flatten_loop: do
1399 i = i + 1
1400 if(i.gt.this%auto_graph%num_vertices) exit flatten_loop
1401 layer_id = this%auto_graph%vertex(i)%id
1402
1403 ! get all child vertices
1404 child_vertices = pack( &
1405 [(j, j=1,size(this%auto_graph%adjacency(i,:)))], &
1406 this%auto_graph%adjacency(i,:) .ne. 0 &
1407 )
1408 child_loop: do j = 1, size(child_vertices)
1409 ! Get layer ID (needed for add() function's output_list parameter)
1410 child_id = this%auto_graph%vertex(child_vertices(j))%id
1411 if(trim(this%model(layer_id)%layer%type).eq."flat") cycle child_loop
1412 if( this%model(layer_id)%layer%output_rank .eq. &
1413 this%model(child_id)%layer%input_rank ) cycle child_loop
1414 if(this%model(layer_id)%layer%output_rank.eq.0) cycle child_loop
1415
1416 ! get all parent vertices of the child vertex
1417 parent_vertices = pack( &
1418 [(k, k=1,size(this%auto_graph%adjacency(:,child_vertices(j))))], &
1419 this%auto_graph%adjacency(:,child_vertices(j)) .ne. 0 &
1420 )
1421 l_flatten_child = .true.
1422 do k = 1, size(parent_vertices)
1423 parent_id = this%auto_graph%vertex(parent_vertices(k))%id
1424 !check if ranks match, rather than input and output shapes
1425 if( this%model(layer_id)%layer%output_rank .ne. &
1426 this%model(parent_id)%layer%input_rank &
1427 ) l_flatten_child = .false.
1428 end do
1429 t_flatten_layer = flatten_layer_type( &
1430 input_rank = this%model(layer_id)%layer%output_rank &
1431 )
1432
1433 if(l_flatten_child)then
1434 ! add flatten layer in the place of the child layer
1435 operator = this%auto_graph%edge( &
1436 this%auto_graph%adjacency(parent_vertices(1),child_vertices(j)) &
1437 )%id
1438 call this%auto_graph%remove_edges( &
1439 indices = [ &
1440 this%auto_graph%adjacency( &
1441 parent_vertices(:),child_vertices(j) &
1442 ) &
1443 ] &
1444 )
1445 call this%add( &
1446 t_flatten_layer, &
1447 input_list=[parent_vertices(:)], output_list=[child_id], &
1448 operator=operator &
1449 )
1450 else
1451 ! add flatten layer between the current layer and the child layer
1452 call this%auto_graph%remove_edges( &
1453 indices = [this%auto_graph%adjacency(i,child_vertices(j))] &
1454 )
1455 call this%add( &
1456 t_flatten_layer, input_list = [i], output_list = [child_id], &
1457 operator=operator &
1458 )
1459 end if
1460 deallocate(t_flatten_layer)
1461 deallocate(child_vertices)
1462 cycle flatten_loop
1463 end do child_loop
1464 deallocate(child_vertices)
1465 end do flatten_loop
1466 call this%build_vertex_order()
1467
1468
1469 !---------------------------------------------------------------------------
1470 ! Check for required merge layers
1471 !---------------------------------------------------------------------------
1472 i = 0
1473 merge_loop: do
1474 i = i + 1
1475 if(i.gt.this%auto_graph%num_vertices) exit merge_loop
1476 layer_id = this%auto_graph%vertex(i)%id
1477 if(this%model(layer_id)%layer%type.eq."merg") cycle merge_loop
1478
1479 ! get all parent vertices of this layer
1480 parent_vertices = pack( &
1481 [(j, j=1,size(this%auto_graph%adjacency(:,i)))], &
1482 this%auto_graph%adjacency(:,i) .ne. 0 &
1483 )
1484 if(size(parent_vertices).le.1) cycle merge_loop
1485
1486 ! get edge id for merge layer
1487 operator = this%auto_graph%edge( &
1488 this%auto_graph%adjacency(parent_vertices(1),i) &
1489 )%id
1490
1491 ! remove edges from parents to this layer
1492 do j = 1, size(parent_vertices)
1493 call this%auto_graph%remove_edges( &
1494 indices = [this%auto_graph%adjacency(parent_vertices(j),i)] &
1495 )
1496 end do
1497 parent_ids = &
1498 [ ( &
1499 this%auto_graph%vertex(parent_vertices(k))%id, &
1500 k = 1, size(parent_vertices) &
1501 ) ]
1502 select case(operator)
1503 case(1) ! concatenate
1504 t_merge_layer = concat_layer_type( &
1505 input_layer_ids = parent_ids, &
1506 input_rank = this%model(layer_id)%layer%input_rank &
1507 )
1508 case(2) ! add
1509 t_merge_layer = add_layer_type( &
1510 input_layer_ids = parent_ids, &
1511 input_rank = this%model(layer_id)%layer%input_rank &
1512 )
1513 ! case(3) ! multiply
1514 ! t_merge_layer = multiply_layer_type( &
1515 ! input_layer_ids = parent_vertices &
1516 ! )
1517 case default
1518 write(0,*) "invalid merge operator: ", operator
1519 call stop_program("invalid merge operator")
1520 return
1521 end select
1522 t_merge_layer%use_graph_input = this%model(layer_id)%layer%use_graph_input
1523 t_merge_layer%use_graph_output = t_merge_layer%use_graph_input
1524 call this%add( &
1525 t_merge_layer, &
1526 input_list = parent_ids, &
1527 output_list = [layer_id] &
1528 )
1529 deallocate(t_merge_layer)
1530 end do merge_loop
1531 call this%build_vertex_order()
1532
1533
1534 ! Update number of layers
1535 !---------------------------------------------------------------------------
1536 this%num_layers = size(this%model,dim=1)
1537
1538
1539
1540 !---------------------------------------------------------------------------
1541 ! Initialise layers
1542 !---------------------------------------------------------------------------
1543 do i = 1, size(this%vertex_order, dim = 1)
1544 vertex_idx = this%vertex_order(i)
1545 layer_id = this%auto_graph%vertex(vertex_idx)%id
1546 if(allocated(this%model(layer_id)%layer%input_shape))then
1547 l_set_input_shape = .false.
1548 else
1549 l_set_input_shape = .true.
1550 end if
1551 if(l_set_input_shape)then
1552 layer_rank = this%model(layer_id)%layer%input_rank
1553 parent_rank = 0
1554
1555 select type( layer => this%model(layer_id)%layer )
1556 class is(merge_layer_type)
1557 ! loop over all parent layers
1558 allocate( &
1559 merge_shape( &
1560 this%model(layer_id)%layer%input_rank, &
1561 size(layer%input_layer_ids) &
1562 ) &
1563 )
1564 do k = 1, size(layer%input_layer_ids)
1565 merge_shape(:,k) = &
1566 this%model(layer%input_layer_ids(k))%layer%output_shape
1567 end do
1568 input_shape = layer%calc_input_shape(merge_shape)
1569 deallocate(merge_shape)
1570 class default
1571
1572 allocate( &
1573 input_shape(this%model(layer_id)%layer%input_rank), &
1574 source = 0 &
1575 )
1576 do j = 1, this%auto_graph%num_vertices
1577 if(this%auto_graph%adjacency(j,vertex_idx).eq.0) cycle
1578 parent_id = this%auto_graph%vertex(j)%id
1579 parent_rank = this%model(parent_id)%layer%output_rank
1580
1581 if(layer_rank .eq. parent_rank)then
1582 input_shape(:) = input_shape(:) + &
1583 this%model(parent_id)%layer%output_shape
1584 elseif(layer_rank .eq. 1)then
1585 input_shape(1) = input_shape(1) + &
1586 product( this%model(parent_id)%layer%output_shape )
1587 end if
1588 end do
1589 end select
1590 call this%model(layer_id)%layer%init( &
1591 input_shape = input_shape, &
1592 verbose = verbose_ &
1593 )
1594 deallocate(input_shape)
1595 end if
1596 if(verbose_.gt.0)then
1597 write(*,*) "layer: ", layer_id, this%model(layer_id)%layer%type
1598 write(*,*) this%model(layer_id)%layer%input_shape
1599 write(*,*) this%model(layer_id)%layer%output_shape
1600 end if
1601 end do
1602
1603
1604 ! Set number of outputs
1605 !---------------------------------------------------------------------------
1606 this%num_outputs = 0
1607 call this%build_leaf_vertices()
1608 do i = 1, size(this%leaf_vertices,1)
1609 this%num_outputs = this%num_outputs + &
1610 product( &
1611 this%model( &
1612 this%auto_graph%vertex(this%leaf_vertices(i))%id &
1613 )%layer%output_shape &
1614 )
1615 end do
1616 if( &
1617 this%model( &
1618 this%auto_graph%vertex(this%leaf_vertices(1))%id &
1619 )%layer%use_graph_output &
1620 )then
1621 this%use_graph_output = .true.
1622 else
1623 this%use_graph_output = .false.
1624 end if
1625
1626
1627 !---------------------------------------------------------------------------
1628 ! Confirm input_shape of each layer matches data going into it
1629 !---------------------------------------------------------------------------
1630 do i = 1, size(this%vertex_order, dim = 1)
1631 vertex_idx = this%vertex_order(i)
1632 layer_id = this%auto_graph%vertex(vertex_idx)%id
1633 if(this%model(layer_id)%layer%type.eq."inpt") cycle
1634
1635 ! Get all parent vertices that feed into this layer
1636 parent_vertices = pack( &
1637 [(j, j=1,size(this%auto_graph%adjacency(:,vertex_idx)))], &
1638 this%auto_graph%adjacency(:,vertex_idx) .ne. 0 &
1639 )
1640 if(size(parent_vertices).eq.0) cycle
1641 select type( layer => this%model(layer_id)%layer )
1642 class is(merge_layer_type)
1643 operator = layer%merge_mode
1644 class default
1645 if(size(parent_vertices).gt.1)then
1646 call stop_program( &
1647 "layer "//trim(layer%name)// &
1648 " is not a merge layer but has multiple inputs" &
1649 )
1650 return
1651 end if
1652 end select
1653
1654 ! Calculate expected input size from parent layers
1655 num_inputs = 0
1656 do j = 1, size(parent_vertices)
1657 parent_vertex = parent_vertices(j)
1658
1659 select case(operator)
1660 case(1) ! pointwise - all inputs should have same size
1661 if(num_inputs.eq.0)then
1662 if(this%model(layer_id)%layer%use_graph_input)then
1663 num_inputs = this%model(parent_vertex)%layer%output_shape(1)
1664 else
1665 num_inputs = product(this%model(parent_vertex)%layer%output_shape)
1666 end if
1667 end if
1668 case(2) ! concatenate
1669 if(this%model(layer_id)%layer%use_graph_input)then
1670 num_inputs = num_inputs + &
1671 this%model(parent_vertex)%layer%output_shape(1)
1672 else
1673 num_inputs = num_inputs + &
1674 product(this%model(parent_vertex)%layer%output_shape)
1675 end if
1676 end select
1677 end do
1678
1679 ! Verify calculated input size matches layer's expected input size
1680 if(this%model(layer_id)%layer%use_graph_input)then
1681 if(num_inputs.ne.this%model(layer_id)%layer%input_shape(1) .and. &
1682 num_inputs.ne.0)then
1683 write(*,*) "Expected:", num_inputs, "Got:", &
1684 this%model(layer_id)%layer%input_shape(1)
1685 call stop_program( &
1686 "input_shape of layer "//&
1687 trim(this%model(layer_id)%layer%name)// &
1688 " does not match data going into it" &
1689 )
1690 end if
1691 else
1692 if(num_inputs.ne.product(this%model(layer_id)%layer%input_shape) .and. &
1693 num_inputs.ne.0)then
1694 write(*,*) "Expected:", num_inputs, "Got:", &
1695 product(this%model(layer_id)%layer%input_shape)
1696 call stop_program( &
1697 "input_shape of layer "//&
1698 trim(this%model(layer_id)%layer%name)// &
1699 " does not match data going into it" &
1700 )
1701 end if
1702 end if
1703
1704 end do
1705
1706 !---------------------------------------------------------------------------
1707 ! Initialise optimiser
1708 !---------------------------------------------------------------------------
1709 this%num_params = this%get_num_params()
1710 if(present(optimiser))then
1711 this%optimiser = optimiser
1712 end if
1713 if(.not.allocated(this%optimiser))then
1714 call stop_program("No optimiser is defined for the network")
1715 return
1716 else
1717 call this%optimiser%init(num_params=this%num_params)
1718 end if
1719
1720
1721 !---------------------------------------------------------------------------
1722 ! Set batch size, if provided
1723 !---------------------------------------------------------------------------
1724 if(present(batch_size)) this%batch_size = batch_size
1725
1726 end subroutine compile
1727 !###############################################################################
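! Illustrative sketch (not part of the covered source): the pack + implied-do
! idiom used in compile above to collect the parent vertices of a vertex from
! one column of the adjacency matrix. The adjacency values are assumptions.
program pack_parents_example
   implicit none
   integer :: adjacency(4,4)
   integer, allocatable :: parents(:)
   integer :: j, vertex_idx
   adjacency = 0
   adjacency(1,3) = 1   ! edge 1 -> 3
   adjacency(2,3) = 1   ! edge 2 -> 3
   vertex_idx = 3
   ! indices j for which column vertex_idx has a nonzero entry
   parents = pack( &
        [(j, j = 1, size(adjacency, 1))], &
        adjacency(:, vertex_idx) /= 0 &
        )
   print *, parents   ! expected: 1 2
end program pack_parents_example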
1728
1729
1730 !###############################################################################
1731 module subroutine set_batch_size(this, batch_size)
1732 !! Set the batch size for the network
1733 implicit none
1734
1735 ! Arguments
1736 class(network_type), intent(inout) :: this
1737 !! Instance of network
1738 integer, intent(in) :: batch_size
1739 !! Batch size
1740
1741 ! Local variables
1742 integer :: l
1743 !! Loop index
1744
1745
1746 this%batch_size = batch_size
1747
1748 end subroutine set_batch_size
1749 !###############################################################################
1750
1751
1752 !###############################################################################
1753 module subroutine reset_state(this)
1754 !! Reset the hidden state of all layers in the network
1755 implicit none
1756
1757 ! Arguments
1758 class(network_type), intent(inout) :: this
1759 !! Instance of network
1760
1761 ! Local variables
1762 integer :: l
1763 !! Loop index
1764
1765 do l = 1, size(this%model, dim = 1)
1766 select type( layer => this%model(l)%layer )
1767 class is(recurrent_layer_type)
1768 call layer%reset_state()
1769 end select
1770 end do
1771
1772 end subroutine reset_state
1773 !###############################################################################
1774
1775
1776 !###############################################################################
1777 module function layer_from_id(this, id) result(layer)
1778 !! Get layer from its ID
1779 implicit none
1780
1781 ! Arguments
1782 class(network_type), intent(in), target :: this
1783 !! Instance of network
1784 integer, intent(in) :: id
1785 !! Layer ID
1786
1787 class(base_layer_type), pointer :: layer
1788 !! Layer
1789
1790 ! Local variables
1791 integer :: i, itmp1
1792 !! Loop index
1793
1794 itmp1 = 0
1795 do i = 1, size(this%model, dim = 1)
1796 if(this%model(i)%layer%id.eq.id)then
1797 if(itmp1.ne.0)then
1798 call stop_program("multiple layers with same ID found")
1799 return
1800 end if
1801 layer => this%model(i)%layer
1802 itmp1 = itmp1 + 1
1803 end if
1804 end do
1805
1806 end function layer_from_id
1807 !###############################################################################
1808
1809
1810 !##############################################################################!
1811 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
1812 !##############################################################################!
1813
1814
1815 !###############################################################################
1816 module function get_sample_ptr( &
1817 input, start_index, end_index, batch_size &
1818 ) result(sample_ptr)
1819 !! Get samples of batch size from a real array
1820 implicit none
1821
1822 ! Arguments
1823 integer, intent(in) :: start_index, end_index
1824 !! Start and end indices
1825 integer, intent(in) :: batch_size
1826 !! Batch size
1827 real(real32), dimension(..), intent(in), target :: input
1828 !! Input array
1829
1830 real(real32), pointer :: sample_ptr(:,:)
1831 !! Pointer to sample
1832
1833
1834 select rank(input)
1835 rank(2)
1836 sample_ptr(1:size(input(:,1)),1:end_index-start_index+1) => &
1837 input(:,start_index:end_index)
1838 rank(3)
1839 sample_ptr(1:size(input(:,:,1)),1:end_index-start_index+1) => &
1840 input(:,:,start_index:end_index)
1841 rank(4)
1842 sample_ptr(1:size(input(:,:,:,1)),1:end_index-start_index+1) => &
1843 input(:,:,:,start_index:end_index)
1844 rank(5)
1845 sample_ptr(1:size(input(:,:,:,:,1)),1:end_index-start_index+1) => &
1846 input(:,:,:,:,start_index:end_index)
1847 rank(6)
1848 sample_ptr(1:size(input(:,:,:,:,:,1)),1:end_index-start_index+1) => &
1849 input(:,:,:,:,:,start_index:end_index)
1850 rank default
1851 sample_ptr => null()
1852 end select
1853
1854 end function get_sample_ptr
1855 !-------------------------------------------------------------------------------
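! Illustrative sketch (not part of the covered source): the rank(4) branch of
! get_sample_ptr above remaps a contiguous slab of samples to a 2-D view
! without copying. The array name and sizes here are assumptions.
program sample_ptr_example
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   real(real32), dimension(8,8,3,100), target :: images   ! feature dims x samples
   real(real32), pointer :: batch(:,:)
   images = 0._real32
   ! samples 11..20 viewed as a (8*8*3) x 10 array of the same storage
   batch(1:size(images(:,:,:,1)), 1:10) => images(:,:,:,11:20)
   print *, shape(batch)   ! expected: 192 10
end program sample_ptr_example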
1856 module function get_sample_array( &
1857 input, start_index, end_index, batch_size, as_graph &
1858 ) result(sample)
1859 !! Get samples of batch size from a derived type array
1860 implicit none
1861
1862 ! Arguments
1863 integer, intent(in) :: start_index, end_index
1864 !! Start and end indices
1865 integer, intent(in) :: batch_size
1866 !! Batch size
1867 class(array_type), dimension(:,:), intent(in) :: input
1868 !! Input array
1869 logical, intent(in) :: as_graph
1870 !! Boolean whether to treat the input as a graph
1871
1872 type(array_type), dimension(:,:), allocatable :: sample
1873 !! Sample array
1874
1875 ! Local variables
1876 integer :: i, j
1877 !! Loop index
1878
1879 if(as_graph)then
1880 allocate(sample(size(input,1), batch_size))
1881 do i = 1, size(input,1)
1882 do j = start_index, end_index, 1
1883 sample(i, j - start_index + 1)%val = input(i, j)%val
1884 end do
1885 end do
1886 else
1887 allocate(sample(size(input,1), size(input,2)))
1888 do i = 1, size(input,1)
1889 do j = 1, size(input,2)
1890 call sample(i,j)%allocate(array_shape=[input(i,j)%shape, &
1891 end_index - start_index + 1])
1892 sample(i,j)%val = get_sample_ptr( &
1893 input(i,j)%val, start_index, end_index, batch_size &
1894 )
1895 end do
1896 end do
1897 end if
1898
1899 end function get_sample_array
1900 !-------------------------------------------------------------------------------
1901 module function get_sample_graph1d( &
1902 input, start_index, end_index, batch_size &
1903 ) result(sample)
1904 !! Get samples of batch size from a graph
1905 implicit none
1906
1907 ! Arguments
1908 integer, intent(in) :: start_index, end_index
1909 !! Start and end indices
1910 integer, intent(in) :: batch_size
1911 !! Batch size
1912 class(graph_type), dimension(:), intent(in) :: input
1913 !! Input array
1914
1915 type(graph_type), dimension(1, batch_size) :: sample
1916 !! Sample array
1917
1918 sample(1,1:batch_size) = input(start_index:end_index)
1919
1920 end function get_sample_graph1d
1921 !-------------------------------------------------------------------------------
1922 module function get_sample_graph2d( &
1923 input, start_index, end_index, batch_size &
1924 ) result(sample)
1925 !! Get samples of batch size from a graph
1926 implicit none
1927
1928 ! Arguments
1929 integer, intent(in) :: start_index, end_index
1930 !! Start and end indices
1931 integer, intent(in) :: batch_size
1932 !! Batch size
1933 class(graph_type), dimension(:,:), intent(in) :: input
1934 !! Input array
1935
1936 type(graph_type), dimension(size(input,1), batch_size) :: sample
1937 !! Sample array
1938
1939 sample(1:size(input,1),1:batch_size) = input(:,start_index:end_index)
1940
1941 end function get_sample_graph2d
1942 !###############################################################################
1943
1944
1945 !###############################################################################
1946 pure module function get_num_params(this) result(num_params)
1947 !! Get the number of learnable parameters in the network
1948 implicit none
1949
1950 ! Arguments
1951 class(network_type), intent(in) :: this
1952 !! Instance of network
1953 integer :: num_params
1954 !! Number of parameters
1955
1956 ! Local variables
1957 integer :: l, i
1958 !! Loop index
1959
1960 num_params = 0
1961 do l = 1, this%num_layers
1962 select type(current => this%model(l)%layer)
1963 class is(learnable_layer_type)
1964 do i = 1, size(current%params)
1965 num_params = num_params + size(current%params(i)%val, 1)
1966 end do
1967 end select
1968 end do
1969
1970 end function get_num_params
1971 !###############################################################################
1972
1973
1974 !###############################################################################
1975 pure module function get_params(this) result(params)
1976 !! Get learnable parameters
1977 implicit none
1978
1979 ! Arguments
1980 class(network_type), intent(in) :: this
1981 !! Instance of network
1982 real(real32), dimension(this%num_params) :: params
1983 !! Parameters
1984
1985 ! Local variables
1986 integer :: l, i, start_idx, end_idx
1987 !! Loop index
1988
1989 start_idx = 0
1990 end_idx = 0
1991 do l = 1, this%num_layers
1992 select type(current => this%model(l)%layer)
1993 class is(learnable_layer_type)
1994 do i = 1, size(current%params)
1995 start_idx = end_idx + 1
1996 end_idx = end_idx + size(current%params(i)%val, 1)
1997 params(start_idx:end_idx) = current%params(i)%val(:,1)
1998 end do
1999 end select
2000 end do
2001
2002 end function get_params
2003 !###############################################################################
2004
2005
2006 !###############################################################################
2007 module subroutine set_params(this, params)
2008 !! Set learnable parameters
2009 implicit none
2010
2011 ! Arguments
2012 class(network_type), intent(inout) :: this
2013 !! Instance of network
2014 real(real32), dimension(this%num_params), intent(in) :: params
2015 !! Parameters
2016
2017 ! Local variables
2018 integer :: l, i, start_idx, end_idx
2019 !! Loop index
2020
2021 start_idx = 0
2022 end_idx = 0
2023 do l = 1, this%num_layers
2024 select type(current => this%model(l)%layer)
2025 class is(learnable_layer_type)
2026 do i = 1, size(current%params)
2027 start_idx = end_idx + 1
2028 end_idx = end_idx + size(current%params(i)%val, 1)
2029 current%params(i)%val(:,1) = params(start_idx:end_idx)
2030 end do
2031 ! call current%set_params(params(start_idx:end_idx))
2032 end select
2033 end do
2034
2035 end subroutine set_params
2036 !###############################################################################
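! Usage sketch (illustrative only): round-tripping the flattened parameter
! vector. It assumes a compiled network "net" of network_type and that
! get_params is bound under the same name as the module procedure, as
! set_params is (see the call in update below).
!   real(real32), allocatable :: flat(:)
!   flat = net%get_params()      ! gather all learnable parameters
!   call net%set_params(flat)    ! write them back unchanged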
2037
2038
2039 !###############################################################################
2040 pure module function get_gradients(this) result(gradients)
2041 !! Get gradients
2042 implicit none
2043
2044 ! Arguments
2045 class(network_type), intent(in) :: this
2046 !! Instance of network
2047 real(real32), dimension(this%num_params) :: gradients
2048 !! Gradients
2049
2050 ! Local variables
2051 integer :: l, i, start_idx, end_idx
2052 !! Loop index
2053
2054 start_idx = 0
2055 end_idx = 0
2056 do l = 1, this%num_layers
2057 select type(current => this%model(l)%layer)
2058 class is(learnable_layer_type)
2059 do i = 1, size(current%params)
2060 if(associated(current%params(i)%grad))then
2061 start_idx = end_idx + 1
2062 end_idx = end_idx + size(current%params(i)%val, 1)
2063 gradients(start_idx:end_idx) = [ &
2064 sum(current%params(i)%grad%val, dim=2) / &
2065 real(size(current%params(i)%grad%val, dim=2), real32) &
2066 ]
2067 end if
2068 end do
2069 end select
2070 end do
2071 call this%optimiser%clip_dict%apply(size(gradients),gradients)
2072
2073 end function get_gradients
2074 !###############################################################################
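! Illustrative sketch (not part of the covered source): averaging a per-sample
! gradient buffer over the batch (second) dimension, mirroring the
! sum(...)/size(...) expression used in get_gradients above. Sizes are
! assumptions for the example.
program grad_average_example
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   integer :: i
   real(real32) :: grad(4,3)   ! 4 parameters x batch of 3 samples
   real(real32) :: avg(4)
   grad = reshape([(real(i, real32), i = 1, 12)], [4, 3])
   ! batch-mean gradient
   avg = sum(grad, dim=2) / real(size(grad, dim=2), real32)
   print *, avg   ! expected: 5.0 6.0 7.0 8.0
end program grad_average_example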
2075
2076
2077 !###############################################################################
2078 module subroutine set_gradients(this, gradients)
2079 !! Set gradients
2080 implicit none
2081
2082 ! Arguments
2083 class(network_type), intent(inout) :: this
2084 !! Instance of network
2085 real(real32), dimension(..), intent(in) :: gradients
2086 !! Gradients
2087
2088 ! Local variables
2089 integer :: l, start_idx, end_idx
2090 !! Loop index
2091
2092 start_idx = 0
2093 end_idx = 0
2094 do l = 1, this%num_layers
2095 select type(current => this%model(l)%layer)
2096 class is(learnable_layer_type)
2097 start_idx = end_idx + 1
2098 end_idx = end_idx + current%num_params
2099 select rank(gradients)
2100 rank(0)
2101 call current%set_gradients(gradients)
2102 rank(1)
2103 call current%set_gradients(gradients(start_idx:end_idx))
2104 end select
2105 end select
2106 end do
2107
2108 end subroutine set_gradients
2109 !###############################################################################
2110
2111
2112 !###############################################################################
2113 module subroutine reset_gradients(this)
2114 !! Reset gradients
2115 implicit none
2116
2117 ! Arguments
2118 class(network_type), intent(inout) :: this
2119 !! Instance of network
2120
2121 ! Local variables
2122 integer :: l, i
2123 !! Loop index
2124
2125 do l = 1, this%num_layers
2126 select type(current => this%model(l)%layer)
2127 class is(learnable_layer_type)
2128 do i = 1, size(current%params)
2129 call current%params(i)%zero_grad()
2130 end do
2131 end select
2132 end do
2133
2134 end subroutine reset_gradients
2135 !###############################################################################
2136
2137
2138 !###############################################################################
2139 module function get_output_shape(this) result(output_shape)
2140 !! Get the output shape of the network
2141 implicit none
2142
2143 ! Arguments
2144 class(network_type), intent(in) :: this
2145 !! Instance of network
2146 integer, dimension(2) :: output_shape
2147 !! Output shape
2148
2149 ! Local variables
2150 integer :: i, layer_idx
2151 !! Loop indices
2152
2153
2154 ! array data: [ layer idx, empty ]
2155 ! graph data: [ vertex/edge idx, sample idx]
2156
2157 if(this%use_graph_output)then
2158 output_shape = [2, this%batch_size]
2159 do i = 1, size(this%leaf_vertices,1), 1
2160 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2161 if(size(this%model(layer_idx)%layer%output,2).ne.this%batch_size)then
2162 call stop_program( &
2163 "Inconsistent batch size in output layers" &
2164 )
2165 return
2166 end if
2167 output_shape(1) = output_shape(1) + &
2168 size( this%model(layer_idx)%layer%output, 1 )
2169 end do
2170 else
2171 output_shape = [0, 1]
2172 do i = 1, size(this%leaf_vertices,1)
2173 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2174 if(size(this%model(layer_idx)%layer%output,2).ne.1)then
2175 call stop_program( &
2176 "Inconsistent size of dimension 2 in output layers" &
2177 )
2178 return
2179 end if
2180 output_shape(1) = &
2181 output_shape(1) + size( this%model(layer_idx)%layer%output, 1 )
2182 end do
2183 end if
2184
2185 end function get_output_shape
2186 !-------------------------------------------------------------------------------
2187 module function get_output(this) result(output)
2188 !! Get the output of the network
2189 implicit none
2190
2191 ! Arguments
2192 class(network_type), intent(in) :: this
2193 !! Instance of network
2194 type(array_type), dimension(:,:), allocatable :: output
2195 !! Output
2196
2197 ! Local variables
2198 integer :: i, start_idx, end_idx, layer_idx, output_id
2199 !! Loop indices
2200 integer, dimension(2) :: output_shape
2201 !! Output shape
2202 integer, dimension(this%num_outputs) :: output_ids
2203 !! Output IDs
2204
2205
2206 ! array data: [ layer idx, empty ]
2207 ! graph data: [ vertex/edge idx, sample idx]
2208
2209 if(this%use_graph_output)then
2210 output_shape = [2, this%batch_size]
2211 do i = 1, size(this%leaf_vertices,1), 1
2212 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2213 if(size(this%model(layer_idx)%layer%output,2).ne.this%batch_size)then
2214 call stop_program( &
2215 "Inconsistent batch size in output layers" &
2216 )
2217 return
2218 end if
2219 output_id = this%model(layer_idx)%layer%id
2220 output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
2221 output_shape(1) = output_shape(1) + output_ids(output_id)
2222 end do
2223 allocate(output(output_shape(1), output_shape(2)))
2224 do i = 1, size(this%leaf_vertices,1)
2225 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2226 output_id = sum(output_ids(1:this%model(layer_idx)%layer%id-1)) + 1
2227 output(output_id,:) = this%model(layer_idx)%layer%output(1,:)
2228 if(output_ids(this%model(layer_idx)%layer%id).gt.1)then
2229 output(output_id+1,:) = this%model(layer_idx)%layer%output(2,:)
2230 end if
2231 end do
2232 else
2233 output_shape = [0, 1]
2234 do i = 1, size(this%leaf_vertices,1)
2235 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2236 if(size(this%model(layer_idx)%layer%output,2).ne.1)then
2237 call stop_program( &
2238 "Inconsistent size of dimension 2 in output layers" &
2239 )
2240 return
2241 end if
2242 output_shape(1) = &
2243 output_shape(1) + size( this%model(layer_idx)%layer%output, 1 )
2244 output_id = this%model(layer_idx)%layer%id
2245 output_ids(output_id) = size( this%model(layer_idx)%layer%output, 1 )
2246 end do
2247 allocate(output(output_shape(1), output_shape(2)))
2248 start_idx = 1
2249 end_idx = 0
2250 do i = 1, size(this%leaf_vertices,1)
2251 layer_idx = this%auto_graph%vertex(this%leaf_vertices(i))%id
2252 output_id = this%model(layer_idx)%layer%id
2253 end_idx = end_idx + output_ids(output_id)
2254 output(start_idx:end_idx,1) = this%model(layer_idx)%layer%output(:,1)
2255 start_idx = end_idx + 1
2256 end do
2257 end if
2258
2259 end function get_output
2260 !-------------------------------------------------------------------------------
2261 module subroutine extract_output_real(this, output)
2262 !! Get the output of the network as a real array
2263 implicit none
2264
2265 ! Arguments
2266 class(network_type), intent(in) :: this
2267 !! Instance of network
2268 real(real32), dimension(..), allocatable, intent(out) :: output
2269 !! Output
2270
2271 ! Local variables
2272 integer :: layer_id
2273 !! Layer ID
2274 character(len=10) :: rank_str
2275 !! String for rank
2276
2277 ! check if number of leaf vertices is 1
2278 if(size(this%leaf_vertices,1).gt.1)then
2279 call print_warning("Output extraction to real array only works for single &
2280 &output networks")
2281 return
2282 end if
2283
2284 ! Get output from the first (and only) leaf vertex
2285 layer_id = this%auto_graph%vertex(this%leaf_vertices(1))%id
2286 call this%model(layer_id)%layer%output(1,1)%extract(output)
2287
2288 end subroutine extract_output_real
2289 !###############################################################################
2290
2291
2292 !##############################################################################!
2293 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
2294 !##############################################################################!
2295
2296
2297 !###############################################################################
2298 module function accuracy_eval(this, output, start_index, end_index) &
2299 result(accuracy)
2300 !! Get the accuracy for the output
2301 implicit none
2302
2303 ! Arguments
2304 class(network_type), intent(in) :: this
2305 !! Instance of network
2306 class(*), dimension(:,:), intent(in) :: output
2307 !! Output
2308 integer, intent(in) :: start_index, end_index
2309 !! Start and end batch indices
2310
2311 real(real32) :: accuracy
2312 !! Accuracy value
2313
2314 ! Local variables
2315 integer :: s, s_idx
2316 !! Loop index
2317
2318 accuracy = 0._real32
2319 select type(output)
2320 type is(graph_type)
2321 do s = start_index, end_index, 1
2322 s_idx = s - start_index + 1
2323 accuracy = accuracy + sum( this%get_accuracy( &
2324 this%model(this%leaf_vertices(1))%layer%output(1,s_idx)%val, &
2325 output(1,s)%vertex_features &
2326 ) ) / output(1,s)%num_vertices
2327 if( &
2328 this%model(this%leaf_vertices(1))%layer%output_shape(2).gt.0 &
2329 )then
2330 accuracy = accuracy + sum( this%get_accuracy( &
2331 this%model(this%leaf_vertices(1))%layer%output(2,s_idx)%val, &
2332 output(1,s)%edge_features &
2333 ) ) / output(1,s)%num_edges
2334 end if
2335 end do
2336 type is(real(real32))
2337 accuracy = sum( &
2338 this%get_accuracy( &
2339 this%model(this%leaf_vertices(1))%layer%output(1,1)%val, &
2340 output(:,start_index:end_index:1) &
2341 ))
2342 type is(integer)
2343 accuracy = sum( &
2344 this%get_accuracy( &
2345 this%model(this%leaf_vertices(1))%layer%output(1,1)%val, &
2346 real(output(:,start_index:end_index:1),real32) &
2347 ))
2348 class is(array_type)
2349 accuracy = sum( &
2350 this%get_accuracy( &
2351 this%model(this%leaf_vertices(1))%layer%output(1,1)%val, &
2352 output(1,1)%val(:,start_index:end_index:1) &
2353 ))
2354 end select
2355
2356 end function accuracy_eval
2357 !###############################################################################
2358
2359
2360 !###############################################################################
2361 module function loss_eval(this, start_index, end_index) result(loss)
2362 !! Get the loss for the output
2363 implicit none
2364
2365 ! Arguments
2366 class(network_type), intent(inout), target :: this
2367 !! Instance of network
2368 integer, intent(in) :: start_index, end_index
2369 !! Start and end batch indices
2370
2371 type(array_type), pointer :: loss
2372 !! Loss value
2373
2374 ! Local variables
2375 integer :: i, s
2376 !! Loop index
2377 type(array_type), pointer :: expected(:,:), predicted(:,:)
2378
2379
2380 if(this%use_graph_output)then
2381 expected(1:2, 1: end_index - start_index + 1) => &
2382 this%expected_array( :, start_index:end_index )
2383 else
2384 allocate(expected(size(this%expected_array,1), size(this%expected_array,2)))
2385 do s = 1, size(this%expected_array,2)
2386 do i = 1, size(this%expected_array,1)
2387 call expected(i,s)%allocate( &
2388 array_shape = [ &
2389 this%expected_array(i,s)%shape, &
2390 size(this%expected_array(i,s)%val,2) &
2391 ] &
2392 )
2393 expected(i,s)%val = this%expected_array(i,s)%val(:, &
2394 start_index:end_index:1)
2395 end do
2396 end do
2397 end if
2398
2399 predicted => this%model(this%leaf_vertices(1))%layer%output
2400 loss => this%loss%compute( &
2401 predicted, &
2402 expected &
2403 )
2404
2405 end function loss_eval
2406 !###############################################################################
2407
2408
2409 !##############################################################################!
2410 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
2411 !##############################################################################!
2412
2413
2414 !###############################################################################
2415 module subroutine forward_generic2d(this, input)
2416 !! Forward pass for generic rank-2 input
2417 implicit none
2418
2419 ! Arguments
2420 class(network_type), intent(inout), target :: this
2421 !! Instance of network
2422 class(*), dimension(:,:), intent(in) :: input
2423 !! Input
2424
2425 ! Local variables
2426 integer :: l, i, j, vertex_idx, layer_id, parent_id
2427 !! Loop index and vertex index
2428 integer :: input_idx
2429 !! Index of input layer
2430 integer :: num_input_layers
2431 !! Number of input layers
2432 type(array_type), pointer :: input_ptr(:,:) => null()
2433 type(array_ptr_type), dimension(:), allocatable :: input_list
2434
2435
2436 select type(input)
2437 type is(graph_type)
2438 do j = 1, this%batch_size
2439 if(any(input(1,j)%adj_ja(1,:).gt.input(1,j)%num_vertices))then
2440 call stop_program( &
2441 "input graph has more vertices than expected" &
2442 )
2443 end if
2444 end do
2445 end select
2446 ! Forward pass
2447 !---------------------------------------------------------------------------
2448 do l = 1, size(this%vertex_order,1)
2449 vertex_idx = this%vertex_order(l)
2450 layer_id = this%auto_graph%vertex(vertex_idx)%id
2451 num_input_layers = count(this%auto_graph%adjacency(:,vertex_idx).gt.0)
2452 if(num_input_layers.eq.0)then
2453 select type(layer => this%model(layer_id)%layer)
2454 class is(input_layer_type)
2455 select type(input)
2456 type is(graph_type)
2457 call layer%set_input_graph( [ input(layer%index, :) ] )
2458 cycle
2459 class is(array_type)
2460 call layer%forward(input(layer%index:layer%index,:))
2461 do concurrent(i=1:size(layer%output,1), j=1:size(layer%output,2))
2462 call layer%output(i,j)%set_requires_grad(.false.)
2463 end do
2464 cycle
2465 type is(real(real32))
2466 allocate(input_ptr(1,1))
2467 call input_ptr(1,1)%allocate(shape(input))
2468 call input_ptr(1,1)%set(input)
2469 call layer%forward(input_ptr)
2470 call layer%output(1,1)%set_requires_grad(.false.)
2471 deallocate(input_ptr)
2472 input_ptr => null()
2473 cycle
2474 class default
2475 call stop_program( &
2476 "input type for layer "// &
2477 trim(layer%name) // &
2478 " is not supported" &
2479 )
2480 end select
2481 class default
2482 return
2483 end select
2484 elseif(num_input_layers.eq.1)then
2485 j = maxloc(this%auto_graph%adjacency(:,vertex_idx),dim=1)
2486 input_idx = findloc(this%root_vertices, j, dim=1)
2487 parent_id = this%auto_graph%vertex(j)%id
2488 input_ptr => this%model(parent_id)%layer%output
2489 select type(input)
2490 type is(graph_type)
2491 call this%model(layer_id)%layer%set_graph( [ input(1,:) ] )
2492 end select
2493 else
2494 allocate(input_list(num_input_layers))
2495 i = 0
2496 do j = 1, size(this%vertex_order,1)
2497 if(this%auto_graph%adjacency(j,vertex_idx).gt.0)then
2498 i = i + 1
2499 parent_id = this%auto_graph%vertex(j)%id
2500 input_list(i)%array => this%model(parent_id)%layer%output
2501 end if
2502 end do
2503 end if
2504
2505 select type(layer => this%model(layer_id)%layer)
2506 class is(merge_layer_type)
2507 call layer%combine(input_list)
2508 deallocate(input_list)
2509 class default
2510 call layer%forward(input_ptr)
2511 input_ptr => null()
2512 end select
2513
2514 end do
2515
2516 end subroutine forward_generic2d
2517 !-------------------------------------------------------------------------------
2518 module function forward_eval(this, input) result(output)
2519 !! Forward pass for evaluation
2520 implicit none
2521
2522 ! Arguments
2523 class(network_type), intent(inout), target :: this
2524 !! Instance of network
2525 class(*), dimension(:,:), intent(in) :: input
2526 !! Input
2527
2528 type(array_type), pointer :: output(:,:)
2529 !! Output
2530
2531 call this%forward(input)
2532 output => this%model(this%leaf_vertices(1))%layer%output
2533
2534 end function forward_eval
2535 !-------------------------------------------------------------------------------
2536 module function forward_eval_multi(this, input) result(output)
2537 !! Forward pass for evaluation
2538 implicit none
2539
2540 ! Arguments
2541 class(network_type), intent(inout), target :: this
2542 !! Instance of network
2543 class(*), dimension(:,:), intent(in) :: input
2544 !! Input
2545
2546 type(array_ptr_type), pointer :: output(:)
2547 !! Output
2548
2549 ! Local variables
2550 integer :: l
2551 !! Loop index
2552
2553 call this%forward(input)
2554 allocate(output(size(this%leaf_vertices,1)))
2555 do l = 1, size(this%leaf_vertices,1)
2556 output(l)%array => this%model(this%leaf_vertices(l))%layer%output
2557 end do
2558
2559 end function forward_eval_multi
2560 !###############################################################################
2561
2562
2563 !###############################################################################
2564 module subroutine update(this)
2565 !! Update the network
2566 implicit none
2567
2568 ! Arguments
2569 class(network_type), intent(inout) :: this
2570 !! Instance of network
2571 real(real32), dimension(this%num_params) :: params, gradients
2572 !! Parameters and gradients
2573
2574 ! Local variables
2575 integer :: l, i, start_idx, end_idx
2576 !! Loop index
2577
2578
2579 !---------------------------------------------------------------------------
2580 ! Increment optimiser iteration counter
2581 !---------------------------------------------------------------------------
2582 if(this%optimiser%lr_decay%iterate_per_epoch)then
2583 if(this%epoch.gt.this%optimiser%epoch)then
2584 this%optimiser%epoch = this%epoch
2585 this%optimiser%iter = this%optimiser%iter + 1
2586 end if
2587 else
2588 this%optimiser%iter = this%optimiser%iter + 1
2589 end if
2590
2591
2592 !---------------------------------------------------------------------------
2593 ! Get learnable parameters and gradients
2594 !---------------------------------------------------------------------------
2595 start_idx = 0
2596 end_idx = 0
2597 do l = 1, this%num_layers
2598 select type(current => this%model(l)%layer)
2599 class is(learnable_layer_type)
2600 do i = 1, size(current%params)
2601 start_idx = end_idx + 1
2602 end_idx = end_idx + size(current%params(i)%val, 1)
2603 params(start_idx:end_idx) = current%params(i)%val(:,1)
2604 if(.not.associated(current%params(i)%grad))then
2605 call stop_program( &
2606 "Gradient not allocated for parameters in layer "// &
2607 trim(current%name) // &
2608 "." &
2609 )
2610 end if
2611 select case(size(current%params(i)%grad%val,2))
2612 case(1)
2613 gradients(start_idx:end_idx) = current%params(i)%grad%val(:,1)
2614 case default
2615 gradients(start_idx:end_idx) = [ &
2616 sum(current%params(i)%grad%val, dim=2) / &
2617 real(size(current%params(i)%grad%val, dim=2), real32) &
2618 ]
2619 end select
2620 end do
2621 end select
2622 end do
2623 ! have an if statement for whether to apply clipping to the gradients of
2624 ! each layer individually or collectively to all gradients at once
2625 call this%optimiser%clip_dict%apply(size(gradients),gradients)
2626
2627 !---------------------------------------------------------------------------
2628 ! Update layers of learnable layer types
2629 !---------------------------------------------------------------------------
2630 call this%optimiser%minimise(params, gradients)
2631 call this%set_params(params)
2632 call this%reset_gradients()
2633
2634 end subroutine update
2635 !###############################################################################
2636
2637
2638 !##############################################################################!
2639 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
2640 !##############################################################################!
2641
2642
2643 !###############################################################################
2644 module subroutine nullify_graph(this)
2645 !! Nullify the input graph
2646 implicit none
2647
2648 ! Arguments
2649 class(network_type), intent(inout) :: this
2650 !! Instance of network
2651
2652 ! Local variables
2653 integer :: l
2654
2655 do l = 1, this%num_layers
2656 call this%model(l)%layer%nullify_graph()
2657 end do
2658
2659 end subroutine nullify_graph
2660 !###############################################################################
2661
2662
2663 !##############################################################################!
2664 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
2665 !##############################################################################!
2666
2667
2668 !###############################################################################
2669 module function save_input_to_network( this, input ) result(num_samples)
2670 !! Save input to network
2671 implicit none
2672
2673 ! Arguments
2674 class(network_type), intent(inout) :: this
2675 !! Instance of network
2676 class(*), dimension(..), intent(in) :: input
2677 !! Input
2678
2679 integer :: num_samples
2680 !! Number of samples
2681
2682 ! Local variables
2683 integer :: i, j, l, ip, input_rank, num_inputs
2684 !! Loop index
2685 integer :: num_input_layers
2686 !! Number of input layers
2687 logical :: l_valid_rank_type
2688 !! Boolean whether rank type is valid
2689 character(256) :: err_msg
2690 !! Error message
2691
2692 num_samples = get_num_samples(this, input)
2693 num_input_layers = size(this%root_vertices, 1)
2694 if(allocated(this%input_array))then
2695 do i = 1, size(this%input_array, 1)
2696 do j = 1, size(this%input_array, 2)
2697 call this%input_array(i,j)%deallocate()
2698 end do
2699 end do
2700 deallocate(this%input_array)
2701 end if
2702 if(allocated(this%input_graph)) deallocate(this%input_graph)
2703
2704 ! Determine the rank of the input
2705 !---------------------------------------------------------------------------
2706 select rank(input)
2707 rank(0)
2708 rank(1)
2709 rank(2)
2710 select type(input)
2711 class is(array_type)
2712 num_inputs = size(input(1,1)%val, 1)
2713 allocate(this%input_array(size(input,1), size(input,2)))
2714 do i = 1, size(input,1)
2715 do j = 1, size(input,2)
2716 call this%input_array(i,j)%assign_shallow(input(i,j))
2717 end do
2718 end do
2719 return
2720 class default
2721 input_rank = rank(input)
2722 num_inputs = size(input) / num_samples
2723 allocate(this%input_array(1,1))
2724 call this%input_array(1,1)%allocate(array_shape=[num_inputs, num_samples])
2725 end select
2726 rank default
2727 input_rank = rank(input)
2728 num_inputs = size(input) / num_samples
2729 allocate(this%input_array(1,1))
2730 call this%input_array(1,1)%allocate(array_shape=shape(input))
2731 end select
2732 l_valid_rank_type = .false.
2733
2734
2735 ! Process input based on its rank
2736 !---------------------------------------------------------------------------
2737 rank_select: select rank(input)
2738 rank(0)
2739 select type(input)
2740 type is(real); exit rank_select
2741 class default; l_valid_rank_type = .true.
2742 end select
2743 if(num_input_layers.ne.1)then
2744 call stop_program( &
2745 "number of input arrays does not match expected number of &
2746 &input layers" &
2747 )
2748 return
2749 end if
2750 select type(input)
2751 class is(array_type)
2752 allocate(this%input_array(1,1))
2753 call handle_array_type(input, this%input_array(1,1), num_samples)
2754 type is(array_ptr_type)
2755 allocate(this%input_array(size(input%array,1), size(input%array,2)))
2756 do i = 1, size(input%array,1)
2757 do j = 1, size(input%array,2)
2758 call handle_array_type( &
2759 input%array(i,j), this%input_array(i,j), num_samples &
2760 )
2761 end do
2762 end do
2763 end select
2764 rank(1)
2765 select type(input)
2766 type is(real(real32))
2767 exit rank_select
2768 type is(graph_type)
2769 allocate(this%input_graph(num_input_layers, num_samples))
2770 this%input_graph(1,:) = input(:)
2771 return
2772 class default
2773 l_valid_rank_type = .true.
2774 end select
2775 if(size(input,1).ne.num_input_layers)then
2776 call stop_program( &
2777 "number of input arrays does not match expected number of &
2778 &input layers" &
2779 )
2780 return
2781 end if
2782 select type(input)
2783 class is(array_type)
2784 allocate(this%input_array(1,size(input,1)))
2785 do l = 1, size(input,1)
2786 call handle_array_type(input(l), this%input_array(1,l), num_samples)
2787 end do
2788 type is(array_ptr_type)
2789 call stop_program("Use of array_ptr_type with rank 1 input not yet supported")
2790 return
2791 ! ip = 0
2792 ! do l = 1, size(input,1)
2793 ! do i = 1, size(input%array,1)
2794 ! ip = ip + 1
2795 ! do j = 1, size(input%array,2)
2796 ! call handle_array_type( &
2797 ! input(l)%array(i,j), this%input_array(ip,j), num_samples &
2798 ! )
2799 ! end do
2800 ! end do
2801 ! end do
2802 end select
2803 rank(2)
2804 select type(input)
2805 type is(real(real32))
2806 this%input_array(1,1)%val = reshape(input, [num_inputs, num_samples])
2807 l_valid_rank_type = .true.
2808 type is(graph_type)
2809 num_samples = size(input, dim=2)
2810 allocate(this%input_graph(num_input_layers, num_samples))
2811 this%input_graph(:,:) = input(:,:)
2812 return
2813 type is(array_type)
2814 call stop_program("SHOULD NOT GET HERE")
2815 this%input_array = input
2816 l_valid_rank_type = .true.
2817 end select
2818 rank(3)
2819 select type(input)
2820 type is(real(real32))
2821 call this%input_array(1,1)%set(input)
2822 l_valid_rank_type = .true.
2823 end select
2824 rank(4)
2825 select type(input)
2826 type is(real(real32))
2827 call this%input_array(1,1)%set(input)
2828 l_valid_rank_type = .true.
2829 end select
2830 rank(5)
2831 select type(input)
2832 type is(real(real32))
2833 call this%input_array(1,1)%set(input)
2834 l_valid_rank_type = .true.
2835 end select
2836 end select rank_select
2837
2838 if(.not.l_valid_rank_type)then
2839 write(err_msg,'("Unknown input type for rank ",I0)') input_rank
2840 call stop_program(err_msg)
2841 return
2842 end if
2843
2844 contains
2845
2846 6 function get_num_samples(network, input) result(num_samples)
2847 implicit none
2848 !! Get the number of samples in the input
2849
2850 ! Arguments
2851 type(network_type), intent(in) :: network
2852 !! Instance of network
2853 class(*), dimension(..), intent(in) :: input
2854 !! Input
2855 integer :: num_samples
2856 !! Number of samples
2857
2858 ! Local variables
2859 integer :: layer_id
2860 !! Layer ID
2861 logical :: use_graph_input
2862 !! Whether to use graph input
2863
2864 6 num_samples = 0
2865 4/8 6 layer_id = network%auto_graph%vertex(network%root_vertices(1))%id
2866 2/4 6 use_graph_input = network%model(layer_id)%layer%use_graph_input
2867 select rank(input)
2868 rank(0)
2869 select type(input)
2870 class is(array_type)
2871 num_samples = size(input%val, 2)
2872 class is(array_ptr_type)
2873 num_samples = size(input%array(1,1)%val, 2)
2874 class default
2875 call stop_program("Unknown input type in get_num_samples for rank 0")
2876 return
2877 end select
2878 rank(1)
2879 select type(input)
2880 class is(array_type)
2881 if(use_graph_input)then
2882 num_samples = size(input)
2883 else
2884 num_samples = size(input(1)%val, 2)
2885 end if
2886 class is(array_ptr_type)
2887 if(use_graph_input)then
2888 num_samples = size(input(1)%array, 2)
2889 else
2890 num_samples = size(input(1)%array(1,1)%val, 2)
2891 end if
2892 class is(graph_type)
2893 num_samples = size(input, dim=1)
2894 type is(real)
2895 num_samples = size(input, rank(input))
2896 class default
2897 call stop_program("Unknown input type in get_num_samples for rank 1")
2898 return
2899 end select
2900 rank(2)
2901 1/2 6 select type(input)
2902 class is(array_type)
2903 1/2 2 if(use_graph_input)then
2904 num_samples = size(input, 2)
2905 else
2906 4/8 1 num_samples = size(input(1,1)%val, 2)
2907 end if
2908 class is(graph_type)
2909 4 num_samples = size(input, dim=2)
2910 type is(real)
2911 1 num_samples = size(input, rank(input))
2912 class default
2913 call stop_program("Unknown input type in get_num_samples for rank 2")
2914 return
2915 end select
2916 rank(3)
2917 select type(input)
2918 type is(real)
2919 num_samples = size(input, rank(input))
2920 class default
2921 call stop_program("Unknown input type in get_num_samples for rank 3")
2922 return
2923 end select
2924 rank(4)
2925 select type(input)
2926 type is(real)
2927 num_samples = size(input, rank(input))
2928 class default
2929 call stop_program("Unknown input type in get_num_samples for rank 4")
2930 return
2931 end select
2932 rank(5)
2933 select type(input)
2934 type is(real)
2935 num_samples = size(input, rank(input))
2936 class default
2937 call stop_program("Unknown input type in get_num_samples for rank 5")
2938 return
2939 end select
2940 rank default
2941 call stop_program("Unknown input rank in get_num_samples")
2942 return
2943 end select
2944
2945 6 end function get_num_samples
2946
2947
2948 subroutine handle_array_type(input, output, num_samples)
2949 !! Handle array type input
2950
2951 ! Arguments
2952 class(array_type), intent(in) :: input
2953 !! Input
2954 type(array_type), intent(out) :: output
2955 !! Output
2956 integer, intent(in) :: num_samples
2957 !! Number of samples
2958
2959 if(size(input%val,2).ne.num_samples)then
2960 call stop_program("number of samples in input arrays do not match")
2961 return
2962 end if
2963 call output%allocate( array_shape = &
2964 [ product(input%shape(1:input%rank)), num_samples ] &
2965 )
2966 output%val = input%val
2967 end subroutine handle_array_type
2968
2969 end function save_input_to_network
2970 !-------------------------------------------------------------------------------
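! Illustrative sketch (not part of the covered source): an assumed-rank dummy
! with select rank, as get_num_samples above uses to read the trailing sample
! dimension via size(input, rank(input)). Names and sizes are assumptions.
program assumed_rank_example
   use, intrinsic :: iso_fortran_env, only: real32
   implicit none
   real(real32) :: data3(5, 4, 7)   ! 5x4 features, 7 samples
   print *, num_samples_of(data3)   ! expected: 7
contains
   integer function num_samples_of(x) result(n)
      real(real32), dimension(..), intent(in) :: x
      select rank(x)
      rank(3)
         n = size(x, rank(x))
      rank default
         n = 0
      end select
   end function num_samples_of
end program assumed_rank_example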
2971 module subroutine save_output_to_network( this, output )
2972 !! Save output to network
2973 implicit none
2974
2975 ! Arguments
2976 class(network_type), intent(inout) :: this
2977 !! Instance of network
2978 class(*), dimension(:,:), intent(in) :: output
2979 !! Output
2980
2981 ! Local variables
2982 integer :: i, j, s
2983 !! Loop indices
2984
2985 if(allocated(this%expected_array))then
2986 do i = 1, size(this%expected_array, 1)
2987 do j = 1, size(this%expected_array, 2)
2988 call this%expected_array(i,j)%deallocate()
2989 end do
2990 end do
2991 deallocate(this%expected_array)
2992 end if
2993
2994 select type(output)
2995 type is(graph_type)
2996 allocate(this%expected_array(2,size(output,2)))
2997 do s = 1, size(output,2)
2998 if(this%expected_array(1,s)%allocated) &
2999 call this%expected_array(1,s)%deallocate()
3000 if(this%expected_array(2,s)%allocated) &
3001 call this%expected_array(2,s)%deallocate()
3002 call this%expected_array(1,s)%allocate( &
3003 array_shape = [ &
3004 output(1,s)%num_vertex_features, output(1,s)%num_vertices &
3005 ] &
3006 )
3007 call this%expected_array(1,s)%zero_grad()
3008 call this%expected_array(1,s)%set_requires_grad(.false.)
3009 call this%expected_array(1,s)%set( output(1,s)%vertex_features )
3010 this%expected_array(1,s)%is_temporary = .false.
3011 if(output(1,s)%num_edge_features.le.0) cycle
3012 call this%expected_array(2,s)%allocate( &
3013 array_shape = [ &
3014 output(1,s)%num_edge_features, output(1,s)%num_edges &
3015 ] &
3016 )
3017 call this%expected_array(2,s)%set_requires_grad(.false.)
3018 call this%expected_array(2,s)%set( output(1,s)%edge_features )
3019 this%expected_array(2,s)%is_temporary = .false.
3020 end do
3021 class is(array_type)
3022 allocate(this%expected_array(size(output,1),size(output,2)))
3023 do s = 1, size(output,2)
3024 do i = 1, size(output,1)
3025 if(this%expected_array(i,s)%allocated) &
3026 call this%expected_array(i,s)%deallocate()
3027 call this%expected_array(i,s)%allocate( &
3028 array_shape = [ &
3029 output(i,s)%shape, size(output(i,s)%val,2) &
3030 ] &
3031 )
3032 call this%expected_array(i,s)%set_requires_grad(.false.)
3033 call this%expected_array(i,s)%set( output(i,s)%val )
3034 this%expected_array(i,s)%is_temporary = .false.
3035 end do
3036 end do
3037 type is(real)
3038 allocate(this%expected_array(1,1))
3039 if(this%expected_array(1,1)%allocated) &
3040 call this%expected_array(1,1)%deallocate()
3041 call this%expected_array(1,1)%allocate( &
3042 array_shape = [ size(output,1), size(output,2) ] &
3043 )
3044 call this%expected_array(1,1)%set_requires_grad(.false.)
3045 call this%expected_array(1,1)%set( output )
3046 this%expected_array(1,1)%is_temporary = .false.
3047 type is(integer)
3048 allocate(this%expected_array(1,1))
3049 if(this%expected_array(1,1)%allocated) &
3050 call this%expected_array(1,1)%deallocate()
3051 call this%expected_array(1,1)%allocate( &
3052 array_shape = [ size(output,1), size(output,2) ] &
3053 )
3054 call this%expected_array(1,1)%set_requires_grad(.false.)
3055 this%expected_array(1,1)%val = real(output, real32)
3056 this%expected_array(1,1)%is_temporary = .false.
3057 class default
3058 call stop_program("output type not supported in training")
3059 end select
3060
3061 end subroutine save_output_to_network
3062 !###############################################################################
3063
3064
3065 !##############################################################################!
3066 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
3067 !##############################################################################!
3068
3069
3070 !###############################################################################
3071 module subroutine train( &
3072 this, input, output, num_epochs, batch_size, &
3073 plateau_threshold, shuffle_batches, batch_print_step, verbose &
3074 )
3075 !! Train the network
3076 !!
3077 !! This function trains the network on the input data for a number of
3078 !! epochs. The input data is split into batches of size batch_size and
3079 !! the network is trained on each batch. The network is trained using
3080 !! the optimiser specified in the network object.
3081 use athena__tools_infile, only: stop_check
3082 implicit none
3083
3084 ! Arguments
3085 class(network_type), intent(inout) :: this
3086 !! Instance of network
3087 class(*), dimension(..), intent(in) :: input
3088 !! Input data
3089 class(*), dimension(:,:), intent(in) :: output
3090 !! Output data
3091 integer, intent(in) :: num_epochs
3092 !! Number of epochs
3093 integer, optional, intent(in) :: batch_size
3094 !! Batch size
3095 real(real32), optional, intent(in) :: plateau_threshold
3096 !! Plateau threshold
3097 logical, optional, intent(in) :: shuffle_batches
3098 !! Shuffle batches
3099 integer, optional, intent(in) :: batch_print_step
3100 !! Batch print step
3101 integer, optional, intent(in) :: verbose
3102 !! Verbosity level
3103
3104 ! Training parameters
3105 real(real32) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy
3106 !! Loss and accuracy
3107
3108 ! Learning parameters
3109 integer :: l, num_samples
3110 !! Loop index
3111 integer :: num_batches
3112 !! Number of batches
3113 integer :: converged
3114 !! Convergence flag
3115 integer :: window_width
3116 !! Length of convergence check window
3117 integer :: verbose_
3118 !! Verbosity level
3119 integer :: batch_print_step_
3120 !! Batch print step
3121 real(real32) :: plateau_threshold_
3122 !! Plateau threshold
3123 logical :: shuffle_batches_
3124 !! Shuffle batches
3125
3126 ! Training loop variables
3127 integer :: epoch, batch, start_index, end_index
3128 !! Loop index
3129 integer, allocatable, dimension(:) :: batch_order
3130 !! Batch order
3131
3132 integer :: i, j, s, time, time_old, clock_rate
3133 !! Loop index
3134
3135 class(*), allocatable :: data_poly(:,:)
3136 type(array_type), pointer :: loss => null()
3137
3138 #ifdef _OPENMP
3139 type(network_type) :: this_copy
3140 !! Copy of network
3141 #endif
3142 ! integer :: timer_start = 0, timer_stop = 0, timer_sum = 0, timer_tot = 0
3143 ! integer :: forward_timer = 0, backward_timer = 0, update_timer = 0
3144
3145
3146 !---------------------------------------------------------------------------
3147 ! Check loss and accuracy methods are set
3148 !---------------------------------------------------------------------------
3149 if(.not.allocated(this%loss))then
3150 call stop_program("loss method not set")
3151 return
3152 end if
3153 if(.not.associated(this%get_accuracy))then
3154 call stop_program("accuracy method not set")
3155 return
3156 end if
3157
3158
3159 !---------------------------------------------------------------------------
3160 ! Initialise optional arguments
3161 !---------------------------------------------------------------------------
3162 verbose_ = 0
3163 batch_print_step_ = 20
3164 plateau_threshold_ = 0._real32
3165 shuffle_batches_ = .true.
3166 if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold
3167 if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches
3168 if(present(batch_print_step)) batch_print_step_ = batch_print_step
3169 if(present(verbose)) verbose_ = verbose
3170 if(present(batch_size)) this%batch_size = batch_size
3171
3172
3173 !---------------------------------------------------------------------------
3174 ! Initialise monitoring variables
3175 !---------------------------------------------------------------------------
3176 window_width = max(ceiling(500._real32/this%batch_size),1)
3177 do i = 1, size(this%metrics,dim=1)
3178 this%metrics(i)%window_width = window_width
3179 end do
3180
3181
3182 !---------------------------------------------------------------------------
3183 ! Save input and output to network
3184 !---------------------------------------------------------------------------
3185 num_samples = this%save_input( input )
3186 call this%save_output( output )
3187 if(size(output,2).ne.num_samples.and.this%use_graph_output)then
3188 call stop_program("number of samples in input and output do not match")
3189 return
3190 end if
3191
3192
3193 !---------------------------------------------------------------------------
3194 ! If parallel, initialise slices
3195 !---------------------------------------------------------------------------
3196 select type(output)
3197 type is(graph_type)
3198 num_batches = size(output,dim=2) / this%batch_size
3199 class is(array_type)
3200 if(this%use_graph_output)then
3201 num_batches = size(output,dim=2) / this%batch_size
3202 else
3203 num_batches = size(output(1,1)%val,dim=2) / this%batch_size
3204 end if
3205 class default
3206 num_batches = size(output,dim=2) / this%batch_size
3207 end select
3208 allocate(batch_order(num_batches))
3209 do batch = 1, num_batches
3210 batch_order(batch) = batch
3211 end do
3212
3213
3214 !---------------------------------------------------------------------------
3215 ! Set/reset batch size for training
3216 !---------------------------------------------------------------------------
3217 call this%set_batch_size(this%batch_size)
3218
3219
3220 !---------------------------------------------------------------------------
3221 ! Turn off inference booleans
3222 !---------------------------------------------------------------------------
3223 do l = 1, this%num_layers
3224 this%model(l)%layer%inference = .false.
3225 end do
3226
3227
3228 !---------------------------------------------------------------------------
3229 ! Query system clock
3230 !---------------------------------------------------------------------------
3231 call system_clock(time, count_rate = clock_rate)
3232
3233
3234 epoch_loop: do epoch = 1, num_epochs
3235 this%epoch = epoch
3236 !------------------------------------------------------------------------
3237 ! Shuffle batch order at the start of each epoch
3238 !------------------------------------------------------------------------
3239 if(shuffle_batches_)then
3240 call shuffle(batch_order)
3241 end if
3242
3243 avg_loss = 0._real32
3244 avg_accuracy = 0._real32
3245
3246 !------------------------------------------------------------------------
3247 ! Batch loop
3248 ! ... split data up into minibatches for training
3249 !------------------------------------------------------------------------
3250 batch_loop: do batch = 1, num_batches
3251
3252
3253 ! Set batch start and end index
3254 !---------------------------------------------------------------------
3255 start_index = (batch_order(batch) - 1) * this%batch_size + 1
3256 end_index = batch_order(batch) * this%batch_size
3257
3258
3259 ! Forward pass
3260 !---------------------------------------------------------------------
3261 ! call system_clock(timer_start)
3262 select case(this%use_graph_input)
3263 case(.true.)
3264 data_poly = get_sample( &
3265 this%input_graph, start_index, end_index, this%batch_size &
3266 )
3267 case default
3268 data_poly = get_sample( &
3269 this%input_array, start_index, end_index, this%batch_size, &
3270 as_graph = .false. &
3271 )
3272 end select
3273 call this%forward(data_poly)
3274 deallocate(data_poly)
3275 ! call system_clock(timer_stop)
3276 ! forward_timer = forward_timer + timer_stop - timer_start
3277
3278
3279 ! Backward pass
3280 !---------------------------------------------------------------------
3281 ! call system_clock(timer_start)
3282 loss => this%loss_eval(start_index, end_index)
3283 loss%is_temporary = .false.
3284 call loss%grad_reverse(reset_graph=.true.)
3285 ! call system_clock(timer_stop)
3286 ! backward_timer = backward_timer + timer_stop - timer_start
3287
3288
3289 ! Compute loss and accuracy (for monitoring)
3290 !---------------------------------------------------------------------
3291 batch_loss = sum(loss%val)
3292 batch_accuracy = this%accuracy_eval(output, start_index, end_index)
3293
3294
3295 ! Average metric over batch size and store
3296 ! Check metric convergence
3297 !---------------------------------------------------------------------
3298 avg_loss = avg_loss + batch_loss
3299 avg_accuracy = avg_accuracy + batch_accuracy
3300 call this%metrics(1)%append(batch_loss)
3301 call this%metrics(2)%append(batch_accuracy / this%batch_size)
3302 do i = 1, size(this%metrics,dim=1)
3303 call this%metrics(i)%check(plateau_threshold_, converged)
3304 if(converged.ne.0)then
3305 exit epoch_loop
3306 end if
3307 end do
3308
3309
3310 ! Update weights and biases using optimisation algorithm
3311 !---------------------------------------------------------------------
3312 ! call system_clock(timer_start)
3313 call this%update()
3314 ! call system_clock(timer_stop)
3315 ! update_timer = update_timer + timer_stop - timer_start
3316 call loss%nullify_graph()
3317 deallocate(loss)
3318 nullify(loss)
3319
3320
3321 ! Print batch results
3322 !---------------------------------------------------------------------
3323 if(abs(verbose_).gt.0.and.&
3324 (batch.eq.1.or.abs(mod(batch,batch_print_step_)).lt.1.E-6))then
3325 write(6,'("epoch=",I0,", batch=",I0,&
3326 &", learning_rate=",F0.3,", loss=",F0.3,", accuracy=",F0.3)' &
3327 ) &
3328 this%epoch, batch, &
3329 this%optimiser%lr_decay%get_lr( &
3330 this%optimiser%learning_rate, this%optimiser%iter &
3331 ), &
3332 avg_loss/batch, &
3333 avg_accuracy/(batch*this%batch_size)
3334 end if
3335
3336
3337 ! Time check
3338 !---------------------------------------------------------------------
3339 if(verbose_.eq.-2)then
3340 time_old = time
3341 call system_clock(time)
3342 write(*,'("time check: ",F5.3," seconds")') &
3343 real(time-time_old)/clock_rate
3344 time_old = time
3345 end if
3346
3347
3347 ! Check for user-created stop file (STOPCAR)
3349 !---------------------------------------------------------------------
3350 if(stop_check())then
3351 write(0,*) "STOPCAR ENCOUNTERED"
3352 write(0,*) "Exiting training loop..."
3353 exit epoch_loop
3354 end if
3355
3356 end do batch_loop
3357
3358
3359 ! Print epoch summary results
3360 !------------------------------------------------------------------------
3361 if(verbose_.eq.0)then
3362 write(6,'("epoch=",I0,&
3363 &", learning_rate=",F0.3,", val_loss=",F0.3,&
3364 &", val_accuracy=",F0.3)' &
3365 ) &
3366 this%epoch, &
3367 this%optimiser%lr_decay%get_lr( &
3368 this%optimiser%learning_rate, this%optimiser%iter &
3369 ), &
3370 this%metrics(1)%val, this%metrics(2)%val
3371 end if
3372
3373
3374 end do epoch_loop
3375
3376 ! write(*,*) "forward timer: ", real(forward_timer)/clock_rate
3377 ! write(*,*) "backward timer: ", real(backward_timer)/clock_rate
3378 ! write(*,*) "update timer: ", real(update_timer)/clock_rate
3379
3380 end subroutine train
3381 !###############################################################################
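! Usage sketch (illustrative only): a typical call to train. It assumes a
! compiled network "net" and arrays "x_train"/"y_train" prepared elsewhere, and
! that the type-bound name matches the module procedure name train.
!   call net%train(x_train, y_train, num_epochs=10, batch_size=32, &
!        shuffle_batches=.true., verbose=1)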
3382
3383
3384 !###############################################################################
3385 module subroutine test( &
3386 this, input, output, verbose &
3387 )
3388 !! Test the network
3389 implicit none
3390
3391 ! Arguments
3392 class(network_type), intent(inout) :: this
3393 !! Instance of network
3394 class(*), dimension(..), intent(in) :: input
3395 !! Input data
3396 class(*), dimension(:,:), intent(in) :: output
3397 !! Output data
3398 integer, optional, intent(in) :: verbose
3399 !! Verbosity level
3400
3401 ! Local variables
3402 integer :: l, sample, num_samples
3403 !! Loop index
3404 integer :: verbose_
3405 !! Verbosity level
3406 real(real32) :: acc_val, loss_val
3407 !! Loss and accuracy
3408 class(*), allocatable, dimension(:,:) :: data_poly
3409 !! Polymorphic data array
3410 type(array_type), pointer :: loss => null()
3411 !! Loss
3412
3413
3414 !---------------------------------------------------------------------------
3415 ! Initialise optional arguments
3416 !---------------------------------------------------------------------------
3417 if(present(verbose))then
3418 verbose_ = verbose
3419 else
3420 verbose_ = 0
3421 end if
3422
3423 do l = 1, size(this%metrics,dim=1)
3424 this%metrics(l)%val = 0._real32
3425 end do
3426 loss_val = 0._real32
3427 acc_val = 0._real32
3428
3429
3430 num_samples = this%save_input( input )
3431
3432
3433 !---------------------------------------------------------------------------
3434 ! Reset batch size for testing
3435 !---------------------------------------------------------------------------
3436 call this%set_batch_size(1)
3437
3438
3439 !---------------------------------------------------------------------------
3440 ! Turn on inference booleans
3441 !---------------------------------------------------------------------------
3442 do l = 1, this%num_layers
3443 this%model(l)%layer%inference = .true.
3444 end do
3445
3446
3447 !---------------------------------------------------------------------------
3448 ! Testing loop
3449 !---------------------------------------------------------------------------
3450 test_loop1: do sample = 1, num_samples
3451
3452 ! Forward pass
3453 !------------------------------------------------------------------------
3454 select case(this%use_graph_input)
3455 case(.true.)
3456 data_poly = get_sample( &
3457 this%input_graph, sample, sample, 1 &
3458 )
3459 case default
3460 data_poly = get_sample_array( &
3461 this%input_array, sample, sample, 1, &
3462 as_graph = .false. &
3463 )
3464 end select
3465 call this%forward(data_poly)
3466 deallocate(data_poly)
3467
3468
3469 ! Compute loss and accuracy (for monitoring)
3470 !------------------------------------------------------------------------
3471 loss => this%loss_eval(sample, sample)
3472 loss_val = sum(loss%val)
3473 call loss%nullify_graph()
3474 deallocate(loss)
3475 nullify(loss)
3476 acc_val = this%accuracy_eval(output, sample, sample)
3477
3478 this%metrics(2)%val = this%metrics(2)%val + acc_val
3479 this%metrics(1)%val = this%metrics(1)%val + loss_val
3480
3481 end do test_loop1
3482
3483
3484 ! Normalise metrics by number of samples
3485 !---------------------------------------------------------------------------
3486 this%accuracy_val = this%metrics(2)%val / real(num_samples, real32)
3487 this%loss_val = this%metrics(1)%val / real(num_samples, real32)
3488
3489 end subroutine test
3490 !###############################################################################
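
The test routine above switches every layer to inference mode, performs a per-sample forward pass, accumulates loss and accuracy, and finally normalises both by the number of samples. A minimal usage sketch follows; the public module name athena, the type-bound name test, and direct access to the loss_val and accuracy_val components are assumptions, and the array shapes are purely illustrative.

   program example_test
      use athena, only: network_type            ! assumed public module name
      use, intrinsic :: iso_fortran_env, only: real32
      implicit none
      type(network_type) :: net
      real(real32) :: x(10, 200)   ! 10 input features, 200 samples (illustrative)
      real(real32) :: y(1, 200)    ! 1 target value per sample (illustrative)
      ! ... build, compile and train net; fill x and y ...
      call net%test(x, y, verbose=0)
      ! loss_val and accuracy_val are set at the end of test (assumed accessible)
      write(*,*) "test loss:     ", net%loss_val
      write(*,*) "test accuracy: ", net%accuracy_val
   end program example_test
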
3491
3492
3493 !###############################################################################
3494 module function predict_real( &
3495 this, input, verbose &
3496 ) result(output)
3497 !! Predict the output for a real-valued input array
3498 implicit none
3499
3500 ! Arguments
3501 class(network_type), intent(inout) :: this
3502 !! Instance of network
3503 real(real32), dimension(..), intent(in) :: input
3504 !! Input
3505 integer, optional, intent(in) :: verbose
3506 !! Verbosity level
3507
3508 ! Local variables
3509 integer :: l
3510 !! Loop index
3511 real(real32), dimension(:,:), allocatable :: output
3512 !! Output
3513 integer :: verbose_, batch_size
3514 !! Verbosity level and batch size
3515
3516
3517 !---------------------------------------------------------------------------
3518 ! Initialise optional arguments
3519 !---------------------------------------------------------------------------
3520 if(present(verbose))then
3521 verbose_ = verbose
3522 else
3523 verbose_ = 0
3524 end if
3525
3526 select rank(input)
3527 rank(2)
3528 batch_size = size(input,dim=2)
3529 rank(3)
3530 batch_size = size(input,dim=3)
3531 rank(4)
3532 batch_size = size(input,dim=4)
3533 rank(5)
3534 batch_size = size(input,dim=5)
3535 rank(6)
3536 batch_size = size(input,dim=6)
3537 rank default
3538 batch_size = size(input,dim=rank(input))
3539 end select
3540
3541
3542 !---------------------------------------------------------------------------
3543 ! Reset batch size for testing
3544 !---------------------------------------------------------------------------
3545 call this%set_batch_size(batch_size)
3546
3547
3548 !---------------------------------------------------------------------------
3549 ! Turn on inference booleans
3550 !---------------------------------------------------------------------------
3551 do l = 1, this%num_layers
3552 this%model(l)%layer%inference = .true.
3553 end do
3554
3555
3556 !---------------------------------------------------------------------------
3557 ! Predict
3558 !---------------------------------------------------------------------------
3559 call this%forward(get_sample(input, 1, batch_size, batch_size))
3560
3561 output = this%model(this%leaf_vertices(1))%layer%output(1,1)%val
3562
3563 end function predict_real
3564 !###############################################################################
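
As the select rank block shows, predict_real treats the last dimension of the input as the batch dimension and returns a rank-2 real array taken from the first leaf layer. A usage sketch, assuming the routine is reachable through a generic predict binding and that athena is the public module name:

   program example_predict
      use athena, only: network_type            ! assumed public module name
      use, intrinsic :: iso_fortran_env, only: real32
      implicit none
      type(network_type) :: net
      real(real32) :: x(10, 32)                 ! 10 features, batch of 32 (illustrative)
      real(real32), allocatable :: y(:,:)
      ! ... build and train net; fill x ...
      y = net%predict(x)                        ! assumed generic binding for predict_real
      write(*,*) "output shape:", shape(y)
   end program example_predict
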
3565
3566
3567 !###############################################################################
3568 module function predict_graph1d( this, input, verbose ) result(output)
3569 !! Predict the output for a graph input
3570 implicit none
3571
3572 ! Arguments
3573 class(network_type), intent(inout) :: this
3574 !! Instance of network
3575 type(graph_type), dimension(:), intent(in) :: input
3576 !! Input graph
3577 integer, optional, intent(in) :: verbose
3578 !! Verbosity level
3579
3580 ! Local variables
3581 integer :: l, s
3582 !! Loop indices
3583 type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: output
3584 !! Output graph
3585 integer :: verbose_, batch_size
3586 !! Verbosity level and batch size
3587
3588
3589 !---------------------------------------------------------------------------
3590 ! Initialise optional arguments
3591 !---------------------------------------------------------------------------
3592 verbose_ = 0; if(present(verbose)) verbose_ = verbose
3593
3594 !---------------------------------------------------------------------------
3595 ! Reset batch size for testing
3596 !---------------------------------------------------------------------------
3597 batch_size = size(input)
3598 call this%set_batch_size(batch_size)
3599
3600
3601 !---------------------------------------------------------------------------
3602 ! Turn on inference booleans
3603 !---------------------------------------------------------------------------
3604 do l = 1, this%num_layers
3605 this%model(l)%layer%inference = .true.
3606 end do
3607
3608
3609 !---------------------------------------------------------------------------
3610 ! Predict
3611 !---------------------------------------------------------------------------
3612 call this%forward(get_sample(input, 1, batch_size, batch_size))
3613
3614 do l = 1, size(this%leaf_vertices)
3615 do s = 1, batch_size
3616 output(l,s)%num_vertices = input(s)%num_vertices
3617 output(l,s)%num_edges = input(s)%num_edges
3618 output(l,s)%num_vertex_features = this%model( &
3619 this%leaf_vertices(l) &
3620 )%layer%output_shape(1)
3621 output(l,s)%num_edge_features = this%model( &
3622 this%leaf_vertices(l) &
3623 )%layer%output_shape(2)
3624 output(l,s)%vertex_features = this%model( &
3625 this%leaf_vertices(l) &
3626 )%layer%output(1,s)%val
3627 if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
3628 output(l,s)%edge_features = input(s)%edge_features
3629 else
3630 output(l,s)%edge_features = this%model( &
3631 this%leaf_vertices(l) &
3632 )%layer%output(2,s)%val
3633 end if
3634 end do
3635 end do
3636
3637 end function predict_graph1d
3638 !-------------------------------------------------------------------------------
3639 module function predict_graph2d( this, input, verbose ) result(output)
3640 !! Predict the output for a graph input
3641 implicit none
3642
3643 ! Arguments
3644 class(network_type), intent(inout) :: this
3645 !! Instance of network
3646 type(graph_type), dimension(:,:), intent(in) :: input
3647 !! Input graph
3648 integer, optional, intent(in) :: verbose
3649 !! Verbosity level
3650
3651 ! Local variables
3652 integer :: l, s
3653 !! Loop indices
3654 type(graph_type), dimension(size(this%leaf_vertices),size(input,dim=2)) :: &
3655 output
3656 !! Output graph
3657 integer :: verbose_, batch_size
3658 !! Verbosity level and batch size
3659
3660
3661 !---------------------------------------------------------------------------
3662 ! Initialise optional arguments
3663 !---------------------------------------------------------------------------
3664 verbose_ = 0; if(present(verbose)) verbose_ = verbose
3665
3666 !---------------------------------------------------------------------------
3667 ! Reset batch size for testing
3668 !---------------------------------------------------------------------------
3669 batch_size = size(input, 2)
3670 call this%set_batch_size(batch_size)
3671
3672
3673 !---------------------------------------------------------------------------
3674 ! Turn on inference booleans
3675 !---------------------------------------------------------------------------
3676 do l = 1, this%num_layers
3677 this%model(l)%layer%inference = .true.
3678 end do
3679
3680
3681 !---------------------------------------------------------------------------
3682 ! Predict
3683 !---------------------------------------------------------------------------
3684 call this%forward(get_sample(input, 1, batch_size, batch_size))
3685
3686 do l = 1, size(this%leaf_vertices)
3687 do s = 1, batch_size
3688 output(l,s)%num_vertices = input(1,s)%num_vertices
3689 output(l,s)%num_edges = input(1,s)%num_edges
3690 output(l,s)%num_vertex_features = this%model( &
3691 this%leaf_vertices(l) &
3692 )%layer%output_shape(1)
3693 output(l,s)%num_edge_features = this%model( &
3694 this%leaf_vertices(l) &
3695 )%layer%output_shape(2)
3696 output(l,s)%vertex_features = this%model( &
3697 this%leaf_vertices(l) &
3698 )%layer%output(1,s)%val
3699 if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
3700 output(l,s)%edge_features = input(1,s)%edge_features
3701 else
3702 output(l,s)%edge_features = this%model( &
3703 this%leaf_vertices(l) &
3704 )%layer%output(2,s)%val
3705 end if
3706 end do
3707 end do
3708
3709 end function predict_graph2d
3710 !###############################################################################
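
Both graph prediction routines return one output graph per leaf vertex and per sample, copying the vertex and edge counts from the input and taking the feature arrays from the leaf layers. A sketch of the rank-1 case, assuming graph_type and a generic predict binding are exported by the athena module:

   program example_predict_graph
      use athena, only: network_type, graph_type   ! assumed public names
      implicit none
      type(network_type) :: net
      type(graph_type) :: g(8)                     ! batch of 8 input graphs (illustrative)
      type(graph_type), allocatable :: out(:,:)    ! (leaf vertex, sample)
      ! ... build and train net; fill g with vertex and edge features ...
      out = net%predict(g)                         ! assumed generic binding for predict_graph1d
      write(*,*) "vertex feature shape of first output graph:", &
           shape(out(1,1)%vertex_features)
   end program example_predict_graph
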
3711
3712
3713 !###############################################################################
3714 module function predict_array_from_real( this, input, output_as_array, verbose ) &
3715 result(output)
3716 !! Predict the output for a generic input
3717 implicit none
3718
3719 ! Arguments
3720 class(network_type), intent(inout) :: this
3721 !! Instance of network
3722 class(*), dimension(..), intent(in) :: input
3723 !! Input data
3724 logical, intent(in) :: output_as_array
3725 !! Whether to output as array
3726 integer, intent(in), optional :: verbose
3727 !! Verbosity level
3728
3729 type(array_type), dimension(:,:), allocatable :: output
3730 !! Predicted output
3731
3732 ! Local variables
3733 integer :: l, s, i
3734 !! Loop indices
3735 integer :: num_samples
3736 !! Number of samples
3737 integer :: verbose_
3738 !! Verbosity level
3739 logical, dimension(:), allocatable :: inference_store
3740 !! Inference store
3741
3742 !---------------------------------------------------------------------------
3743 ! Initialise optional arguments
3744 !---------------------------------------------------------------------------
3745 if(present(verbose))then
3746 verbose_ = verbose
3747 else
3748 verbose_ = 0
3749 end if
3750 if(.not.output_as_array)then
3751 call stop_program("predict_array_from_real: output_as_array must be true")
3752 return
3753 end if
3754
3755
3756 !---------------------------------------------------------------------------
3757 ! Set number of samples for predicting
3758 !---------------------------------------------------------------------------
3759 num_samples = this%save_input( input )
3760 ! call this%set_batch_size(num_samples)
3761
3762
3763 !---------------------------------------------------------------------------
3764 ! Turn on inference booleans
3765 !---------------------------------------------------------------------------
3766 allocate(inference_store(this%num_layers))
3767 do l = 1, this%num_layers
3768 inference_store(l) = this%model(l)%layer%inference
3769 this%model(l)%layer%inference = .true.
3770 end do
3771
3772 !---------------------------------------------------------------------------
3773 ! Forward pass
3774 !---------------------------------------------------------------------------
3775 select case(this%use_graph_input)
3776 case(.true.)
3777 call this%forward(this%input_graph)
3778 case default
3779 call this%forward(this%input_array)
3780 end select
3781
3782
3783 !---------------------------------------------------------------------------
3784 ! Allocate output data
3785 !---------------------------------------------------------------------------
3786 allocate(output( &
3787 size(this%model(this%leaf_vertices(1))%layer%output, 1), &
3788 size(this%model(this%leaf_vertices(1))%layer%output, 2) &
3789 ))
3790 do s = 1, size(this%model(this%leaf_vertices(1))%layer%output, 2)
3791 do i = 1, size(this%model(this%leaf_vertices(1))%layer%output, 1)
3792 output(i,s) = this%model(this%leaf_vertices(1))%layer%output(i,s)
3793 end do
3794 end do
3795
3796 !---------------------------------------------------------------------------
3797 ! Reset inference booleans
3798 !---------------------------------------------------------------------------
3799 do l = 1, this%num_layers
3800 this%model(l)%layer%inference = inference_store(l)
3801 end do
3802
3803 end function predict_array_from_real
3804 !###############################################################################
3805
3806
3807 !###############################################################################
3808 module function predict_array( this, input, verbose ) &
3809 result(output)
3810 !! Predict the output for an array_type input
3811 implicit none
3812
3813 ! Arguments
3814 class(network_type), intent(inout) :: this
3815 !! Instance of network
3816 class(array_type), dimension(..), intent(in) :: input
3817 !! Input data array
3818 integer, intent(in), optional :: verbose
3819 !! Verbosity level
3820
3821 type(array_type), dimension(:,:), allocatable :: output
3822 !! Predicted output
3823
3824 ! Local variables
3825 integer :: l, s, i, j, layer_id
3826 !! Loop indices and layer identifier
3827 integer :: num_samples
3828 !! Number of samples
3829 integer :: verbose_
3830 !! Verbosity level
3831 integer, dimension(2) :: output_shape
3832 !! Output shape
3833 logical, dimension(:), allocatable :: inference_store
3834 !! Inference store
3835
3836
3837 !---------------------------------------------------------------------------
3838 ! Initialise optional arguments
3839 !---------------------------------------------------------------------------
3840 if(present(verbose))then
3841 verbose_ = verbose
3842 else
3843 verbose_ = 0
3844 end if
3845
3846
3847 !---------------------------------------------------------------------------
3848 ! Set number of samples for predicting
3849 !---------------------------------------------------------------------------
3850 num_samples = this%save_input( input )
3851 ! call this%set_batch_size(num_samples)
3852
3853
3854 !---------------------------------------------------------------------------
3855 ! Turn on inference booleans
3856 !---------------------------------------------------------------------------
3857 allocate(inference_store(this%num_layers))
3858 do l = 1, this%num_layers
3859 inference_store(l) = this%model(l)%layer%inference
3860 this%model(l)%layer%inference = .true.
3861 end do
3862
3863 !---------------------------------------------------------------------------
3864 ! Forward pass
3865 !---------------------------------------------------------------------------
3866 select case(this%use_graph_input)
3867 case(.true.)
3868 call this%forward(this%input_graph)
3869 case default
3870 call this%forward(this%input_array)
3871 end select
3872
3873
3874 !---------------------------------------------------------------------------
3875 ! Allocate output data
3876 !---------------------------------------------------------------------------
3877 output_shape = this%get_output_shape()
3878 allocate(output(output_shape(1), output_shape(2)))
3879 do l = 1, size(this%leaf_vertices)
3880 layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
3881 j = 0
3882 do i = 1, size(this%model(layer_id)%layer%output, 1)
3883 j = j + 1
3884 do s = 1, size(this%model(layer_id)%layer%output, 2)
3885 output(j,s) = this%model(layer_id)%layer%output(i,s)
3886 end do
3887 end do
3888 end do
3889
3890 !---------------------------------------------------------------------------
3891 ! Reset inference booleans
3892 !---------------------------------------------------------------------------
3893 do l = 1, this%num_layers
3894 this%model(l)%layer%inference = inference_store(l)
3895 end do
3896
3897 end function predict_array
3898 !###############################################################################
3899
3900
3901 !###############################################################################
3902 module function predict_generic( this, input, verbose, output_as_graph ) &
3903 result(output)
3904 !! Predict the output for a generic input
3905 implicit none
3906
3907 ! Arguments
3908 class(network_type), intent(inout) :: this
3909 !! Instance of network
3910 class(*), dimension(:,:), intent(in) :: input
3911 !! Input data
3912 integer, intent(in), optional :: verbose
3913 !! Verbosity level
3914 logical, intent(in), optional :: output_as_graph
3915 !! Boolean whether to output as graph
3916
3917 class(*), dimension(:,:), allocatable :: output
3918 !! Predicted output
3919
3920 ! Local variables
3921 integer :: l, s, i, j, layer_id
3922 !! Loop indices and layer identifier
3923 integer :: num_samples
3924 !! Number of samples
3925 integer :: verbose_
3926 !! Verbosity level
3927 logical :: output_as_graph_
3928 !! Output as graph boolean
3929 integer, dimension(2) :: output_shape
3930 !! Output shape
3931
3932
3933 !---------------------------------------------------------------------------
3934 ! Initialise optional arguments
3935 !---------------------------------------------------------------------------
3936 if(present(verbose))then
3937 verbose_ = verbose
3938 else
3939 verbose_ = 0
3940 end if
3941
3942 if(present(output_as_graph))then
3943 output_as_graph_ = output_as_graph
3944 else
3945 output_as_graph_ = .false.
3946 end if
3947 if(output_as_graph_.and..not.this%use_graph_output)then
3948 call stop_program("output_as_graph is true but network does not use &
3949 &graph output")
3950 end if
3951
3952
3953 !---------------------------------------------------------------------------
3954 ! Set number of samples for predicting
3955 !---------------------------------------------------------------------------
3956 num_samples = this%save_input( input )
3957 call this%set_batch_size(num_samples)
3958
3959
3960 !---------------------------------------------------------------------------
3961 ! Turn on inference booleans
3962 !---------------------------------------------------------------------------
3963 do l = 1, this%num_layers
3964 this%model(l)%layer%inference = .true.
3965 end do
3966
3967 !---------------------------------------------------------------------------
3968 ! Forward pass
3969 !---------------------------------------------------------------------------
3970 select case(this%use_graph_input)
3971 case(.true.)
3972 call this%forward(this%input_graph)
3973 case default
3974 call this%forward(this%input_array)
3975 end select
3976
3977
3978 !---------------------------------------------------------------------------
3979 ! Allocate output data
3980 !---------------------------------------------------------------------------
3981 output_shape = this%get_output_shape()
3982 if(output_as_graph_)then
3983 allocate(output(output_shape(1), output_shape(2)), source = graph_type())
3984 select type(output)
3985 type is(graph_type)
3986 select type(input)
3987 type is(graph_type)
3988 do l = 1, size(this%leaf_vertices)
3989 do s = 1, num_samples
3990 output(l,s)%num_vertices = input(1,s)%num_vertices
3991 output(l,s)%num_edges = input(1,s)%num_edges
3992 output(l,s)%num_vertex_features = this%model( &
3993 this%leaf_vertices(l) &
3994 )%layer%output_shape(1)
3995 output(l,s)%num_edge_features = this%model( &
3996 this%leaf_vertices(l) &
3997 )%layer%output_shape(2)
3998 output(l,s)%vertex_features = this%model( &
3999 this%leaf_vertices(l) &
4000 )%layer%output(1,s)%val
4001 if(size(this%model(this%leaf_vertices(l))%layer%output,1).eq.1)then
4002 output(l,s)%edge_features = input(1,s)%edge_features
4003 else
4004 output(l,s)%edge_features = this%model( &
4005 this%leaf_vertices(l) &
4006 )%layer%output(2,s)%val
4007 end if
4008 end do
4009 end do
4010 class default
4011 call stop_program("input is not of type graph_type")
4012 end select
4013 class default
4014 call stop_program("allocation of output as graph_type failed")
4015 end select
4016 else
4017 output_shape = this%get_output_shape()
4018 allocate(output(output_shape(1), output_shape(2)), source = array_type())
4019 select type(output)
4020 type is(array_type)
4021 do l = 1, size(this%leaf_vertices)
4022 layer_id = this%auto_graph%vertex(this%leaf_vertices(l))%id
4023 j = 0
4024 do i = 1, size(this%model(layer_id)%layer%output, 1)
4025 j = j + 1
4026 do s = 1, size(this%model(layer_id)%layer%output, 2)
4027 output(j,s) = this%model(layer_id)%layer%output(i,s)
4028 end do
4029 end do
4030 end do
4031 end select
4032 end if
4033
4034 end function predict_generic
4035 !###############################################################################
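
predict_generic returns an unlimited polymorphic array, so the caller has to resolve the dynamic type with a select type construct. A sketch for array output, assuming array_type (and its val component) is publicly accessible and that the routine is exposed through a binding written here as predict (hypothetical binding name); the keyword argument keeps the call unambiguous:

   program example_predict_generic
      use athena, only: network_type, array_type   ! assumed public names
      use, intrinsic :: iso_fortran_env, only: real32
      implicit none
      type(network_type) :: net
      real(real32) :: x(10, 16)                    ! 10 features, batch of 16 (illustrative)
      class(*), allocatable :: y(:,:)
      ! ... build and train net; fill x ...
      y = net%predict(x, output_as_graph=.false.)  ! assumed binding for predict_generic
      select type(y)
      type is(array_type)
         write(*,*) "values of first output sample:", y(1,1)%val
      end select
   end program example_predict_generic
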
4036
4037
4038 !###############################################################################
4039 module subroutine print_summary(this)
4040 !! Print a summary of the network architecture
4041 implicit none
4042
4043 ! Arguments
4044 class(network_type), intent(in) :: this
4045 !! Instance of network
4046
4047 ! Local variables
4048 integer :: i, vertex_idx
4049 !! Loop index and vertex index
4050 integer :: total_params
4051 !! Parameter counts
4052 integer :: layer_params
4053 !! Parameters in current layer
4054 character(len=80) :: line
4055 !! Line separator
4056 character(len=40) :: layer_name
4057 !! Layer name
4058 character(len=30) :: output_shape_str
4059 !! Output shape string
4060 character(len=20) :: param_str
4061 !! Parameter count string
4062 character(len=100) :: fmt
4063 !! Format string
4064
4065 line = repeat('_', 80)
4066
4067 ! Print header
4068 write(*,*)
4069 write(*,'(A)') line
4070 write(*,'(A)') 'Model Summary'
4071 write(*,'(A)') line
4072 write(*,'(A35, A25, A15)') 'Layer (type)', 'Output Shape', 'Param #'
4073 write(*,'(A)') repeat('=', 80)
4074
4075 ! Initialise parameter count
4076 total_params = 0
4077
4078 ! Print each layer
4079 do i = 1, this%num_layers
4080 vertex_idx = this%vertex_order(i)
4081 associate(layer => this%model(vertex_idx)%layer)
4082 ! Get layer name
4083 if(allocated(layer%name))then
4084 write(layer_name, '(A," (",A,")")') &
4085 trim(layer%name), trim(layer%subtype)
4086 else
4087 write(layer_name, '(A,I0," (",A,")")') &
4088 'layer_', i, trim(layer%subtype)
4089 end if
4090
4091 ! Get output shape string
4092 if(allocated(layer%output_shape))then
4093 ! write the general format for output shape
4094 write(fmt,'("(""(""",A,"I0,"")"")")') &
4095 repeat('I0,", "', size(layer%output_shape)-1)
4096 write(output_shape_str, fmt) layer%output_shape
4097 else
4098 output_shape_str = '(Not set)'
4099 end if
4100
4101 ! Get parameter count
4102 layer_params = layer%get_num_params()
4103 total_params = total_params + layer_params
4104 if(layer_params > 0)then
4105 write(param_str, '(I0)') layer_params
4106 else
4107 param_str = '0'
4108 end if
4109
4110 ! Print layer information
4111 write(*,'(A35, A25, A15)') adjustl(trim(layer_name)), &
4112 adjustl(trim(output_shape_str)), adjustl(trim(param_str))
4113 end associate
4114 end do
4115
4116 ! Print footer
4117 write(*,'(A)') repeat('=', 80)
4118 write(*,'(A,I0)') 'Number of input vertices: ', size(this%root_vertices)
4119 write(*,'(A,I0)') 'Number of output vertices: ', size(this%leaf_vertices)
4120 write(*,'(A,I0)') 'Total trainable params: ', total_params
4121 write(*,'(A)') line
4122 write(*,*)
4123
4124 end subroutine print_summary
4125 !###############################################################################
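
print_summary walks the layers in topological order (vertex_order) and tabulates each layer's name, output shape and parameter count before printing the totals. A trivial usage sketch, assuming the procedure is bound to network_type under the same name and that athena is the public module:

   program example_summary
      use athena, only: network_type             ! assumed public module name
      implicit none
      type(network_type) :: net
      ! ... add layers and compile net ...
      call net%print_summary()                   ! prints the layer table and totals
   end program example_summary
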
4126
4127 end submodule athena__network_submodule
4128