GCC Code Coverage Report


Directory: src/athena/
File: athena_network.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__network
2 !! Module containing the network class used to define a neural network
3 !!
4 !! This module contains the types and interfaces for the network class used
5 !! to define a neural network; procedure bodies are implemented in submodules.
6 !! The network class is used to define a neural network with overloaded
7 !! procedures for training, testing, predicting, and updating the network.
8 !! The network class is also used to define the network structure and
9 !! compile the network with an optimiser, loss function, and accuracy
10 !! function.
11 use coreutils, only: real32
12 use graphstruc, only: graph_type
13 use athena__metrics, only: metric_dict_type
14 use athena__optimiser, only: base_optimiser_type
15 use athena__loss, only: base_loss_type
16 use athena__accuracy, only: comp_acc_func => compute_accuracy_function
17 use athena__base_layer, only: base_layer_type
18 use diffstruc, only: array_type
19 use athena__misc_types, only: &
20 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
21 use athena__container_layer, only: container_layer_type
22 use athena__diffstruc_extd, only: array_ptr_type
23 implicit none
24
25
26 private
27
28 public :: network_type
29
30
31 type :: network_type
32 !! Type for defining a neural network with overloaded procedures
33 character(len=:), allocatable :: name
34 !! Name of the network
35 real(real32) :: accuracy_val, loss_val
36 !! Accuracy and loss of the network
37 integer :: batch_size = 0
38 !! Batch size
39 integer :: epoch = 0
40 !! Epoch number
41 integer :: num_layers = 0
42 !! Number of layers
43 integer :: num_outputs = 0
44 !! Number of outputs
45 integer :: num_params = 0
46 !! Number of parameters
47 logical :: use_graph_input = .false.
48 !! Boolean flag for graph input
49 logical :: use_graph_output = .false.
50 !! Boolean flag for graph output
51 class(base_optimiser_type), allocatable :: optimiser
52 !! Optimiser for the network
53 class(base_loss_type), allocatable :: loss
54 !! Loss method for the network
55 type(metric_dict_type), dimension(2) :: metrics
56 !! Metrics for the network
57 type(container_layer_type), allocatable, dimension(:) :: model
58 !! Model layers
59 character(len=:), allocatable :: loss_method, accuracy_method
60 !! Loss and accuracy method names
61 procedure(comp_acc_func), nopass, pointer :: get_accuracy => null()
62 !! Pointer to accuracy function
63 integer, dimension(:), allocatable :: vertex_order
64 !! Order of vertices
65 integer, dimension(:), allocatable :: root_vertices, leaf_vertices
66 !! Root (input) and leaf (output) vertices
67 type(graph_type) :: auto_graph
68 !! Graph structure for the network
69
70 type(array_type), dimension(:,:), allocatable :: input_array
71 !! Input array for the network
72 type(graph_type), dimension(:,:), allocatable :: input_graph
73 !! Input graph for the network
74 type(array_type), dimension(:,:), allocatable :: expected_array
75 !! Expected output array for the network
76 contains
77 procedure, pass(this) :: print
78 !! Print the network to file
79 procedure, pass(this) :: print_summary
80 !! Print a summary of the network architecture
81 procedure, pass(this) :: read
82 !! Read the network from a file
83 procedure, pass(this), private :: read_network_settings
84 !! Read network settings from a file
85 procedure, pass(this), private :: read_optimiser_settings
86 !! Read optimiser settings from a file
87 procedure, pass(this) :: build_from_onnx
88 !! Build network from ONNX nodes and initialisers
89 procedure, pass(this) :: add
90 !! Add a layer to the network
91 procedure, pass(this) :: reset
92 !! Reset the network
93 procedure, pass(this) :: compile
94 !! Compile the network
95 procedure, pass(this) :: set_batch_size
96 !! Set batch size
97 procedure, pass(this) :: set_metrics
98 !! Set network metrics
99 procedure, pass(this) :: set_loss
100 !! Set network loss method
101 procedure, pass(this) :: set_accuracy
102 !! Set network accuracy method
103 procedure, pass(this) :: reset_state
104 !! Reset hidden state of recurrent layers
105
106 procedure, pass(this) :: save_input => save_input_to_network
107 !! Convert and save polymorphic input to array or graph
108 procedure, pass(this) :: save_output => save_output_to_network
109 !! Convert and save polymorphic output to array or graph
110
111 procedure, pass(this) :: layer_from_id
112 !! Get the layer of the network from its ID
113
114 procedure, pass(this) :: train
115 !! Train the network
116 procedure, pass(this) :: test
117 !! Test the network
118
119 procedure, pass(this) :: predict_real
120 !! Return predicted results from supplied inputs using the trained network
121 procedure, pass(this) :: predict_array_from_real
122 !! Return predicted results as array from supplied inputs using the trained network
123 procedure, pass(this) :: predict_graph1d, predict_graph2d
124 !! Return predicted results from supplied inputs using the trained network (graph input)
125 procedure, pass(this) :: predict_array
126 !! Predict array type output for an array_type input (any rank)
127 procedure, pass(this) :: predict_generic
128 !! Predict generic type output for a generic input
129 generic :: predict => &
130 predict_real, predict_graph1d, predict_graph2d, &
131 predict_array, predict_array_from_real
132 !! Predict function for different input types
133
134
135 procedure, pass(this), private :: dfs
136 !! Depth first search
137 procedure, pass(this), private :: build_vertex_order
138 !! Generate vertex order
139 procedure, pass(this), private :: build_root_vertices
140 !! Calculate root vertices
141 procedure, pass(this), private :: build_leaf_vertices
142 !! Calculate output vertices
143
144 procedure, pass(this) :: reduce => network_reduction
145 !! Reduce two networks down to one (i.e. add two networks - parallel)
146 procedure, pass(this) :: copy => network_copy
147 !! Copy a network
148
149 procedure, pass(this) :: get_num_params
150 !! Get number of learnable parameters in the network
151 procedure, pass(this) :: get_params
152 !! Get learnable parameters
153 procedure, pass(this) :: set_params
154 !! Set learnable parameters
155 procedure, pass(this) :: get_gradients
156 !! Get gradients of learnable parameters
157 procedure, pass(this) :: set_gradients
158 !! Set learnable parameter gradients
159 procedure, pass(this) :: reset_gradients
160 !! Reset learnable parameter gradients
161 procedure, pass(this) :: get_output
162 !! Get the output of the network
163 procedure, pass(this) :: get_output_shape
164 !! Get the output shape of the network
165 procedure, pass(this) :: extract_output => extract_output_real
166 !! Extract network output as real array (only works for single output layer models)
167
168 procedure, pass(this) :: forward => forward_generic2d
169 !! Forward pass for generic 2D input
170 procedure, pass(this) :: forward_eval
171 !! Forward pass and return pointer to output (only works for single output layer models)
172 procedure, pass(this) :: accuracy_eval
173 !! Get the accuracy for the output
174 procedure, pass(this) :: loss_eval
175 !! Get the loss for the output
176 procedure, pass(this) :: update
177 !! Update the learnable parameters of the network based on gradients
178
179 procedure, pass(this) :: nullify_graph
180 !! Nullify graph data in the network to free memory
181 end type network_type
182
183 interface network_type
184 !! Interface for setting up the network (network initialisation)
185 module function network_setup( &
186 layers, &
187 optimiser, loss_method, accuracy_method, &
188 metrics, batch_size &
189 ) result(network)
190 !! Set up the network
191 type(container_layer_type), dimension(:), intent(in) :: layers
192 !! Layers
193 class(base_optimiser_type), optional, intent(in) :: optimiser
194 !! Optimiser
195 class(*), optional, intent(in) :: loss_method
196 !! Loss method
197 character(*), optional, intent(in) :: accuracy_method
198 !! Accuracy method
199 class(*), dimension(..), optional, intent(in) :: metrics
200 !! Metrics
201 integer, optional, intent(in) :: batch_size
202 !! Batch size
203 type(network_type) :: network
204 !! Instance of the network
205 end function network_setup
206 end interface network_type
207
208 interface
209 !! Interface for printing the network to file
210 module subroutine print(this, file)
211 !! Print the network to file
212 class(network_type), intent(in) :: this
213 !! Instance of the network
214 character(*), intent(in) :: file
215 !! File name
216 end subroutine print
217
218 !! Interface for printing a summary of the network
219 module subroutine print_summary(this)
220 !! Print a summary of the network architecture
221 class(network_type), intent(in) :: this
222 !! Instance of the network
223 end subroutine print_summary
224
225 !! Interface for reading the network from a file
226 module subroutine read(this, file)
227 !! Read the network from a file
228 class(network_type), intent(inout) :: this
229 !! Instance of the network
230 character(*), intent(in) :: file
231 !! File name
232 end subroutine read
233
234 !! Interface for reading network settings from a file
235 module subroutine read_network_settings(this, unit)
236 !! Read network settings from a file
237 class(network_type), intent(inout) :: this
238 !! Instance of the network
239 integer, intent(in) :: unit
240 !! Unit number for input
241 end subroutine read_network_settings
242
243 !! Interface for reading optimiser settings from a file
244 module subroutine read_optimiser_settings(this, unit)
245 !! Read optimiser settings from a file
246 class(network_type), intent(inout) :: this
247 !! Instance of the network
248 integer, intent(in) :: unit
249 !! Unit number for input
250 end subroutine read_optimiser_settings
251
252 !! Interface for building network from ONNX nodes and initialisers
253 module subroutine build_from_onnx( &
254 this, nodes, initialisers, inputs, value_info, verbose &
255 )
256 !! Build network from ONNX nodes and initialisers
257 class(network_type), intent(inout) :: this
258 !! Instance of the network
259 type(onnx_node_type), dimension(:), intent(in) :: nodes
260 !! Array of ONNX nodes
261 type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
262 !! Array of ONNX initialisers
263 type(onnx_tensor_type), dimension(:), intent(in) :: inputs
264 !! Array of ONNX input tensors
265 type(onnx_tensor_type), dimension(:), intent(in) :: value_info
266 !! Array of ONNX value info tensors
267 integer, optional, intent(in) :: verbose
268 !! Verbosity level
269 end subroutine build_from_onnx
270
271 !! Interface for adding a layer to the network
272 module subroutine add(this, layer, input_list, output_list, operator)
273 !! Add a layer to the network
274 class(network_type), intent(inout) :: this
275 !! Instance of the network
276 class(base_layer_type), intent(in) :: layer
277 !! Layer to add
278 integer, dimension(:), intent(in), optional :: input_list, output_list
279 !! Input and output list
280 class(*), optional, intent(in) :: operator
281 !! Operator
282 end subroutine add
283
284 !! Interface for resetting the network
285 module subroutine reset(this)
286 !! Reset the network
287 class(network_type), intent(inout) :: this
288 !! Instance of the network
289 end subroutine reset
290
291 !! Interface for compiling the network
292 module subroutine compile( &
293 this, optimiser, loss_method, accuracy_method, &
294 metrics, batch_size, verbose &
295 )
296 !! Compile the network
297 class(network_type), intent(inout) :: this
298 !! Instance of the network
299 class(base_optimiser_type), optional, intent(in) :: optimiser
300 !! Optimiser
301 class(*), optional, intent(in) :: loss_method
302 !! Loss method
303 character(*), optional, intent(in) :: accuracy_method
304 !! Accuracy method
305 class(*), dimension(..), optional, intent(in) :: metrics
306 !! Metrics
307 integer, optional, intent(in) :: batch_size
308 !! Batch size
309 integer, optional, intent(in) :: verbose
310 !! Verbosity level
311 end subroutine compile
312
313 !! Interface for setting batch size
314 module subroutine set_batch_size(this, batch_size)
315 !! Set batch size
316 class(network_type), intent(inout) :: this
317 !! Instance of the network
318 integer, intent(in) :: batch_size
319 !! Batch size
320 end subroutine set_batch_size
321
322 !! Interface for setting network metrics
323 module subroutine set_metrics(this, metrics)
324 !! Set network metrics
325 class(network_type), intent(inout) :: this
326 !! Instance of the network
327 class(*), dimension(..), intent(in) :: metrics
328 !! Metrics
329 end subroutine set_metrics
330
331 !! Interface for setting network loss method
332 module subroutine set_loss(this, loss_method, verbose)
333 !! Set network loss method
334 class(network_type), intent(inout) :: this
335 !! Instance of the network
336 class(*), intent(in) :: loss_method
337 !! Loss method
338 integer, optional, intent(in) :: verbose
339 !! Verbosity level
340 end subroutine set_loss
341
342 !! Interface for setting network accuracy method
343 module subroutine set_accuracy(this, accuracy_method, verbose)
344 !! Set network accuracy method
345 class(network_type), intent(inout) :: this
346 !! Instance of the network
347 character(*), intent(in) :: accuracy_method
348 !! Accuracy method
349 integer, optional, intent(in) :: verbose
350 !! Verbosity level
351 end subroutine set_accuracy
352
353 !! Interface for resetting state of recurrent layers
354 module subroutine reset_state(this)
355 !! Reset hidden state of recurrent layers
356 class(network_type), intent(inout) :: this
357 !! Instance of the network
358 end subroutine reset_state
359
360 !! Interface for saving input to network
361 module function save_input_to_network( this, input ) result(num_samples)
362 !! Convert and save polymorphic input to array or graph
363 class(network_type), intent(inout) :: this
364 !! Instance of network
365 class(*), dimension(..), intent(in) :: input
366 !! Input
367 integer :: num_samples
368 !! Number of samples
369 end function save_input_to_network
370
371 !! Interface for saving output to network
372 module subroutine save_output_to_network( this, output )
373 !! Convert and save polymorphic output to array or graph
374 class(network_type), intent(inout) :: this
375 !! Instance of network
376 class(*), dimension(:,:), intent(in) :: output
377 !! Output
378 end subroutine save_output_to_network
379
380 module function layer_from_id(this, id) result(layer)
381 !! Get the layer of the network from its ID
382 class(network_type), intent(in), target :: this
383 !! Instance of the network
384 integer, intent(in) :: id
385 !! Layer ID
386 class(base_layer_type), pointer :: layer
387 !! Layer pointer
388 end function layer_from_id
389
390
391 !! Interface for training the network
392 module subroutine train( &
393 this, input, output, num_epochs, batch_size, &
394 plateau_threshold, shuffle_batches, batch_print_step, verbose &
395 )
396 !! Train the network
397 class(network_type), intent(inout) :: this
398 !! Instance of the network
399 class(*), dimension(..), intent(in) :: input
400 !! Input data
401 class(*), dimension(:,:), intent(in) :: output
402 !! Expected output data (data labels)
403 integer, intent(in) :: num_epochs
404 !! Number of epochs to train for
405 integer, optional, intent(in) :: batch_size
406 !! Batch size (DEPRECATED)
407 real(real32), optional, intent(in) :: plateau_threshold
408 !! Threshold for checking learning plateau
409 logical, optional, intent(in) :: shuffle_batches
410 !! Shuffle batch order
411 integer, optional, intent(in) :: batch_print_step
412 !! Print step for batch
413 integer, optional, intent(in) :: verbose
414 !! Verbosity level
415 end subroutine train
416
417 !! Interface for testing the network
418 module subroutine test(this, input, output, verbose)
419 !! Test the network
420 class(network_type), intent(inout) :: this
421 !! Instance of the network
422 class(*), dimension(..), intent(in) :: input
423 !! Input data
424 class(*), dimension(:,:), intent(in) :: output
425 !! Expected output data (data labels)
426 integer, optional, intent(in) :: verbose
427 !! Verbosity level
428 end subroutine test
429
430 !! Interface for returning predicted results from supplied inputs
431 !! using the trained network
432 module function predict_real(this, input, verbose) result(output)
433 !! Get predicted results from supplied inputs using the trained network
434 class(network_type), intent(inout) :: this
435 !! Instance of the network
436 real(real32), dimension(..), intent(in) :: input
437 !! Input data
438 integer, optional, intent(in) :: verbose
439 !! Verbosity level
440 real(real32), dimension(:,:), allocatable :: output
441 !! Predicted output data
442 end function predict_real
443
444 module function predict_array_from_real( &
445 this, input, output_as_array, verbose &
446 ) result(output)
447 !! Get predicted results as array from supplied inputs using the trained network
448 class(network_type), intent(inout) :: this
449 !! Instance of the network
450 class(*), dimension(..), intent(in) :: input
451 !! Input data
452 logical, intent(in) :: output_as_array
453 !! Whether to output as array
454 integer, optional, intent(in) :: verbose
455 !! Verbosity level
456 type(array_type), dimension(:,:), allocatable :: output
457 !! Predicted output data as array
458 end function predict_array_from_real
459
460 !! Interface for returning predicted results from supplied inputs
461 !! using the trained network (graph input)
462 module function predict_graph1d(this, input, verbose) result(output)
463 !! Get predicted results from supplied inputs using the trained network
464 class(network_type), intent(inout) :: this
465 !! Instance of the network
466 type(graph_type), dimension(:), intent(in) :: input
467 !! Input data
468 integer, optional, intent(in) :: verbose
469 !! Verbosity level
470 type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: &
471 output
472 !! Predicted output data
473 end function predict_graph1d
474 module function predict_graph2d(this, input, verbose) result(output)
475 !! Get predicted results from supplied inputs using the trained network
476 class(network_type), intent(inout) :: this
477 !! Instance of the network
478 type(graph_type), dimension(:,:), intent(in) :: input
479 !! Input data
480 integer, optional, intent(in) :: verbose
481 !! Verbosity level
482 type(graph_type), dimension(size(this%leaf_vertices),size(input, 2)) :: &
483 output
484 !! Predicted output data
485 end function predict_graph2d
486
487 module function predict_array( this, input, verbose ) &
488 result(output)
489 !! Predict array-type output for an array_type input (any rank)
490 class(network_type), intent(inout) :: this
491 !! Instance of network
492 class(array_type), dimension(..), intent(in) :: input
493 !! Input data (array_type, assumed rank)
494 integer, intent(in), optional :: verbose
495 !! Verbosity level
496 type(array_type), dimension(:,:), allocatable :: output
497 end function predict_array
498
499 module function predict_generic( this, input, verbose, output_as_graph ) &
500 result(output)
501 !! Predict the output for a generic input
502 class(network_type), intent(inout) :: this
503 !! Instance of network
504 class(*), dimension(:,:), intent(in) :: input
505 !! Input data (polymorphic, rank 2)
506 integer, intent(in), optional :: verbose
507 !! Verbosity level
508 logical, intent(in), optional :: output_as_graph
509 !! Boolean whether to output as graph
510 class(*), dimension(:,:), allocatable :: output
511 end function predict_generic
512
513 !! Interface for updating the learnable parameters of the network
514 !! based on gradients
515 module subroutine update(this)
516 !! Update the learnable parameters of the network based on gradients
517 class(network_type), intent(inout) :: this
518 !! Instance of the network
519 end subroutine update
520
521 !! Interface for generating vertex order
522 module subroutine build_vertex_order(this)
523 !! Generate vertex order
524 class(network_type), intent(inout) :: this
525 !! Instance of the network
526 end subroutine build_vertex_order
527
528 !! Interface for depth first search
529 recursive module subroutine dfs( &
530 this, vertex_index, visited, order, order_index &
531 )
532 !! Depth first search
533 class(network_type), intent(in) :: this
534 !! Instance of the network
535 integer, intent(in) :: vertex_index
536 !! Vertex index
537 logical, dimension(this%auto_graph%num_vertices), intent(inout) :: &
538 visited
539 !! Visited vertices
540 integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order
541 !! Order of vertices
542 integer, intent(inout) :: order_index
543 !! Index of order
544 end subroutine dfs
545
546 !! Interface for calculating root vertices
547 module subroutine build_root_vertices(this)
548 !! Calculate root vertices
549 class(network_type), intent(inout) :: this
550 !! Instance of the network
551 end subroutine build_root_vertices
552
553 !! Interface for calculating output vertices
554 module subroutine build_leaf_vertices(this)
555 !! Calculate output vertices
556 class(network_type), intent(inout) :: this
557 !! Instance of the network
558 end subroutine build_leaf_vertices
559
560 !! Interface for reducing two networks down to one
561 !! (i.e. add two networks - parallel)
562 module subroutine network_reduction(this, source)
563 !! Reduce two networks down to one (i.e. add two networks - parallel)
564 class(network_type), intent(inout) :: this
565 !! Instance of the network
566 type(network_type), intent(in) :: source
567 !! Source network
568 end subroutine network_reduction
569
570 !! Interface for copying a network
571 module subroutine network_copy(this, source)
572 !! Copy a network
573 class(network_type), intent(inout) :: this
574 !! Instance of the network
575 type(network_type), intent(in), target :: source
576 !! Source network
577 end subroutine network_copy
578
579 !! Interface for getting number of learnable parameters in the network
580 pure module function get_num_params(this) result(num_params)
581 !! Get number of learnable parameters in the network
582 class(network_type), intent(in) :: this
583 !! Instance of the network
584 integer :: num_params
585 !! Number of parameters
586 end function get_num_params
587
588 !! Interface for getting learnable parameters
589 pure module function get_params(this) result(params)
590 !! Get learnable parameters
591 class(network_type), intent(in) :: this
592 !! Instance of the network
593 real(real32), dimension(this%num_params) :: params
594 !! Learnable parameters
595 end function get_params
596
597 !! Interface for setting learnable parameters
598 module subroutine set_params(this, params)
599 !! Set learnable parameters
600 class(network_type), intent(inout) :: this
601 !! Instance of the network
602 real(real32), dimension(this%num_params), intent(in) :: params
603 !! Learnable parameters
604 end subroutine set_params
605
606 !! Interface for getting gradients of learnable parameters
607 pure module function get_gradients(this) result(gradients)
608 !! Get gradients of learnable parameters
609 class(network_type), intent(in) :: this
610 !! Instance of the network
611 real(real32), dimension(this%num_params) :: gradients
612 !! Gradients
613 end function get_gradients
614
615 !! Interface for setting learnable parameter gradients
616 module subroutine set_gradients(this, gradients)
617 !! Set learnable parameter gradients
618 class(network_type), intent(inout) :: this
619 !! Instance of the network
620 real(real32), dimension(..), intent(in) :: gradients
621 !! Gradients
622 end subroutine set_gradients
623
624 !! Interface for resetting learnable parameter gradients
625 module subroutine reset_gradients(this)
626 !! Reset learnable parameter gradients
627 class(network_type), intent(inout) :: this
628 !! Instance of the network
629 end subroutine reset_gradients
630
631 module function get_output(this) result(output)
632 class(network_type), intent(in) :: this
633 !! Instance of the network
634 type(array_type), dimension(:,:), allocatable :: output
635 !! Output
636 end function get_output
637
638 module function get_output_shape(this) result(output_shape)
639 class(network_type), intent(in) :: this
640 !! Instance of the network
641 integer, dimension(2) :: output_shape
642 !! Output shape
643 end function get_output_shape
644
645 module subroutine extract_output_real(this, output)
646 class(network_type), intent(in) :: this
647 !! Instance of network
648 real(real32), dimension(..), allocatable, intent(out) :: output
649 !! Output
650 end subroutine extract_output_real
651
652 module function accuracy_eval(this, output, start_index, end_index) &
653 result(accuracy)
654 !! Get the accuracy for the output
655 class(network_type), intent(in) :: this
656 !! Instance of network
657 class(*), dimension(:,:), intent(in) :: output
658 !! Output
659 integer, intent(in) :: start_index, end_index
660 !! Start and end batch indices
661 real(real32) :: accuracy
662 !! Accuracy value
663 end function accuracy_eval
664
665 module function loss_eval(this, start_index, end_index) result(loss)
666 !! Get the loss for the output
667 ! Arguments
668 class(network_type), intent(inout), target :: this
669 !! Instance of network
670 integer, intent(in) :: start_index, end_index
671 !! Start and end batch indices
672
673 type(array_type), pointer :: loss
674 end function loss_eval
675
676 !! Interface for forward pass
677 module subroutine forward_generic2d(this, input)
678 !! Forward pass for generic 2D input
679 class(network_type), intent(inout), target :: this
680 !! Instance of the network
681 class(*), dimension(:,:), intent(in) :: input
682 !! Input data
683 end subroutine forward_generic2d
684
685 module function forward_eval(this, input) result(output)
686 !! Forward pass evaluation
687 class(network_type), intent(inout), target :: this
688 !! Instance of the network
689 class(*), dimension(:,:), intent(in) :: input
690 !! Input data
691 type(array_type), pointer :: output(:,:)
692 !! Output data
693 end function forward_eval
694
695 module function forward_eval_multi(this, input) result(output)
696 !! Forward pass evaluation for multiple outputs (NOTE(review): not bound in network_type's contains list - confirm intended)
697 class(network_type), intent(inout), target :: this
698 !! Instance of the network
699 class(*), dimension(:,:), intent(in) :: input
700 !! Input data
701 type(array_ptr_type), pointer :: output(:)
702 !! Output data
703 end function forward_eval_multi
704
705 module subroutine nullify_graph(this)
706 !! Nullify graph data in the network to free memory
707 class(network_type), intent(inout) :: this
708 !! Instance of the network
709 end subroutine nullify_graph
710 end interface
711
712 interface get_sample
713 module function get_sample_ptr( &
714 input, start_index, end_index, batch_size &
715 ) result(sample_ptr)
716 !! Get a rank-2 pointer to a sample range of an assumed-rank real input
717 implicit none
718 ! Arguments
719 integer, intent(in) :: start_index, end_index
720 !! Start and end indices
721 integer, intent(in) :: batch_size
722 !! Batch size
723 real(real32), dimension(..), intent(in), target :: input
724 !! Input array
725 ! Result
726 real(real32), pointer :: sample_ptr(:,:)
727 !! Pointer to sample
728 end function get_sample_ptr
729 module function get_sample_array( &
730 input, start_index, end_index, batch_size, as_graph&
731 ) result(sample)
732 !! Get sample for array_type input (optionally treated as graph data)
733 integer, intent(in) :: start_index, end_index
734 !! Start and end indices
735 integer, intent(in) :: batch_size
736 !! Batch size
737 class(array_type), dimension(:,:), intent(in) :: input
738 !! Input array
739 logical, intent(in) :: as_graph
740 !! Boolean whether to treat the input as a graph
741 type(array_type), dimension(:,:), allocatable :: sample
742 !! Sample array
743 end function get_sample_array
744 module function get_sample_graph1d( &
745 input, start_index, end_index, batch_size &
746 ) result(sample)
747 !! Get sample for graph input
748 integer, intent(in) :: start_index, end_index
749 !! Start and end indices
750 integer, intent(in) :: batch_size
751 !! Batch size
752 class(graph_type), dimension(:), intent(in) :: input
753 !! Input graphs
754 type(graph_type), dimension(1, batch_size) :: sample
755 !! Sample array
756 end function get_sample_graph1d
757 module function get_sample_graph2d( &
758 input, start_index, end_index, batch_size &
759 ) result(sample)
760 !! Get sample for graph input
761 integer, intent(in) :: start_index, end_index
762 !! Start and end indices
763 integer, intent(in) :: batch_size
764 !! Batch size
765 class(graph_type), dimension(:,:), intent(in) :: input
766 !! Input graphs
767 type(graph_type), dimension(size(input,1), batch_size) :: sample
768 !! Sample array
769 end function get_sample_graph2d
770 end interface get_sample
771
772 end module athena__network
773