GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_network.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 0 1 0.0%
Functions: 0 0 -%
Branches: 0 308 0.0%

Line Branch Exec Source
1 module athena__network
2 !! Module containing the network class used to define a neural network
3 !!
4 !! This module contains the types and interfaces for the network class used
5 !! to define a neural network.
6 !! The network class is used to define a neural network with overloaded
7 !! procedures for training, testing, predicting, and updating the network.
8 !! The network class is also used to define the network structure and
9 !! compile the network with an optimiser, loss function, and accuracy
10 !! function.
11 use coreutils, only: real32
12 use graphstruc, only: graph_type
13 use athena__metrics, only: metric_dict_type
14 use athena__optimiser, only: base_optimiser_type
15 use athena__loss, only: base_loss_type
16 use athena__accuracy, only: comp_acc_func => compute_accuracy_function
17 use athena__base_layer, only: base_layer_type
18 use diffstruc, only: array_type
19 use athena__misc_types, only: &
20 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
21 use athena__container_layer, only: container_layer_type
22 use athena__diffstruc_extd, only: array_ptr_type
23 implicit none
24
25
26 private
27
28 public :: network_type
29
30
31 type :: network_type
32 !! Type for defining a neural network with overloaded procedures
33 character(len=:), allocatable :: name
34 !! Name of the network
35 real(real32) :: accuracy_val, loss_val
36 !! Accuracy and loss of the network
37 integer :: batch_size = 0
38 !! Batch size
39 integer :: epoch = 0
40 !! Epoch number
41 integer :: num_layers = 0
42 !! Number of layers
43 integer :: num_outputs = 0
44 !! Number of outputs
45 integer :: num_params = 0
46 !! Number of parameters
47 logical :: use_graph_input = .false.
48 !! Boolean flag for graph input
49 logical :: use_graph_output = .false.
50 !! Boolean flag for graph output
51 class(base_optimiser_type), allocatable :: optimiser
52 !! Optimiser for the network
53 class(base_loss_type), allocatable :: loss
54 !! Loss method for the network
55 type(metric_dict_type), dimension(2) :: metrics
56 !! Metrics for the network
57 type(container_layer_type), allocatable, dimension(:) :: model
58 !! Model layers
59 character(len=:), allocatable :: loss_method, accuracy_method
60 !! Loss and accuracy method names
61 procedure(comp_acc_func), nopass, pointer :: get_accuracy => null()
62 !! Pointer to accuracy function
63 integer, dimension(:), allocatable :: vertex_order
64 !! Order of vertices
65 integer, dimension(:), allocatable :: root_vertices, leaf_vertices
66 !! Root and output vertices
67 type(graph_type) :: auto_graph
68 !! Graph structure for the network
69
70 ! Pre-computed forward pass navigation (populated during compile)
71 integer, dimension(:), allocatable :: fwd_layer_id
72 !! Layer ID for each vertex in forward order
73 integer, dimension(:), allocatable :: fwd_num_inputs
74 !! Number of input layers for each vertex in forward order
75 integer, dimension(:), allocatable :: fwd_parent_id
76 !! Parent layer ID for single-input vertices
77 integer, dimension(:), allocatable :: fwd_layer_type
78 !! Layer type: 0=input, 1=merge, 2=default
79
80 ! Pre-computed parameter segment layout (populated during compile)
81 integer :: param_num_segments = 0
82 !! Number of parameter segments
83 integer, dimension(:), allocatable :: param_seg_layer
84 !! Layer index for each parameter segment
85 integer, dimension(:), allocatable :: param_seg_pidx
86 !! Param index within that layer for each segment
87 integer, dimension(:), allocatable :: param_seg_start
88 !! Start offset in flat parameter array
89 integer, dimension(:), allocatable :: param_seg_end
90 !! End offset in flat parameter array
91
92 type(array_type), dimension(:,:), allocatable :: input_array
93 !! Input array for the network
94 type(graph_type), dimension(:,:), allocatable :: input_graph
95 !! Input graph for the network
96 type(array_type), dimension(:,:), allocatable :: expected_array
97 !! Expected output array for the network
98 contains
99 procedure, pass(this) :: print
100 !! Print the network to file
101 procedure, pass(this) :: print_summary
102 !! Print a summary of the network architecture
103 procedure, pass(this) :: read
104 !! Read the network from a file
105 procedure, pass(this), private :: read_network_settings
106 !! Read network settings from a file
107 procedure, pass(this), private :: read_optimiser_settings
108 !! Read optimiser settings from a file
109 procedure, pass(this) :: build_from_onnx
110 !! Build network from ONNX nodes and initialisers
111 procedure, pass(this) :: add
112 !! Add a layer to the network
113 procedure, pass(this) :: reset
114 !! Reset the network
115 procedure, pass(this) :: compile
116 !! Compile the network
117 procedure, pass(this) :: set_batch_size
118 !! Set batch size
119 procedure, pass(this) :: set_metrics
120 !! Set network metrics
121 procedure, pass(this) :: set_loss
122 !! Set network loss method
123 procedure, pass(this) :: set_accuracy
124 !! Set network accuracy method
125 procedure, pass(this) :: reset_state
126 !! Reset hidden state of recurrent layers
127 procedure, pass(this) :: set_training_mode
128 !! Set training mode for layers with training/inference-specific behaviour
129 procedure, pass(this) :: set_inference_mode
130 !! Set inference mode for layers with training/inference-specific behaviour
131 procedure, pass(this), private :: restore_mode
132 !! Reset the training/inference mode of layers to the values stored in mode_store.
133
134 procedure, pass(this) :: save_input => save_input_to_network
135 !! Convert and save polymorphic input to array or graph
136 procedure, pass(this) :: save_output => save_output_to_network
137 !! Convert and save polymorphic output to array or graph
138
139 procedure, pass(this) :: layer_from_id
140 !! Get the layer of the network from its ID
141
142 procedure, pass(this) :: train
143 !! Train the network
144 procedure, pass(this) :: test
145 !! Test the network
146
147 procedure, pass(this) :: predict_real
148 !! Return predicted results from supplied inputs using the trained network
149 procedure, pass(this) :: predict_array_from_real
150 !! Return predicted results as array from supplied inputs using the trained network
151 procedure, pass(this) :: predict_graph1d, predict_graph2d
152 !! Return predicted results from supplied inputs using the trained network (graph input)
153 procedure, pass(this) :: predict_array
154 !! Predict array type output for a generic input
155 procedure, pass(this) :: predict_generic
156 !! Predict generic type output for a generic input
157 generic :: predict => &
158 predict_real, predict_graph1d, predict_graph2d, &
159 predict_array, predict_array_from_real
160 !! Predict function for different input types
161
162
163 procedure, pass(this), private :: dfs
164 !! Depth first search
165 procedure, pass(this), private :: build_vertex_order
166 !! Generate vertex order
167 procedure, pass(this), private :: build_root_vertices
168 !! Calculate root vertices
169 procedure, pass(this), private :: build_leaf_vertices
170 !! Calculate output vertices
171
172 procedure, pass(this) :: reduce => network_reduction
173 !! Reduce two networks down to one (i.e. add two networks - parallel)
174 procedure, pass(this) :: copy => network_copy
175 !! Copy a network
176
177 procedure, pass(this) :: get_num_params
178 !! Get number of learnable parameters in the network
179 procedure, pass(this) :: get_params
180 !! Get learnable parameters
181 procedure, pass(this) :: set_params
182 !! Set learnable parameters
183 procedure, pass(this) :: get_gradients
184 !! Get gradients of learnable parameters
185 procedure, pass(this) :: set_gradients
186 !! Set learnable parameter gradients
187 procedure, pass(this) :: reset_gradients
188 !! Reset learnable parameter gradients
189 procedure, pass(this) :: get_output
190 !! Get the output of the network
191 procedure, pass(this) :: get_output_shape
192 !! Get the output shape of the network
193 procedure, pass(this) :: extract_output => extract_output_real
194 !! Extract network output as real array (only works for single output layer models)
195
196 procedure, pass(this) :: forward => forward_generic2d
197 !! Forward pass for generic 2D input
198 procedure, pass(this) :: forward_eval
199 !! Forward pass and return pointer to output (only works for single output layer models)
200 procedure, pass(this) :: accuracy_eval
201 !! Get the accuracy for the output
202 procedure, pass(this) :: loss_eval
203 !! Get the loss for the output
204 procedure, pass(this) :: update
205 !! Update the learnable parameters of the network based on gradients
206
207 procedure, pass(this) :: nullify_graph
208 !! Nullify graph data in the network to free memory
209
210 procedure, pass(this) :: post_epoch_hook
211 !! Called after each training epoch; override in derived types for custom
212 !! per-epoch callbacks (e.g. logging to Weights & Biases).
213
214 procedure, pass(this), private :: inverse_design_real
215 !! Inverse design with real inputs
216 procedure, pass(this), private :: inverse_design_array_0d
217 !! Inverse design with 0d array_type inputs
218 procedure, pass(this), private :: inverse_design_array_2d
219 !! Inverse design with 2d array_type inputs
220 generic :: inverse_design => &
221 inverse_design_real, inverse_design_array_0d, inverse_design_array_2d
222 !! Optimise input to match a target output
223 end type network_type
224
225 interface network_type
226 !! Interface for setting up the network (network initialisation)
227 module function network_setup( &
228 layers, &
229 optimiser, loss_method, accuracy_method, &
230 metrics, batch_size &
231 ) result(network)
232 !! Set up the network
233 type(container_layer_type), dimension(:), intent(in) :: layers
234 !! Layers
235 class(base_optimiser_type), optional, intent(in) :: optimiser
236 !! Optimiser
237 class(*), optional, intent(in) :: loss_method
238 !! Loss method
239 character(*), optional, intent(in) :: accuracy_method
240 !! Accuracy method
241 class(*), dimension(..), optional, intent(in) :: metrics
242 !! Metrics
243 integer, optional, intent(in) :: batch_size
244 !! Batch size
245 type(network_type) :: network
246 !! Instance of the network
247 end function network_setup
248 end interface network_type
249
250 interface
251 !! Interface for printing the network to file
252 module subroutine print(this, file)
253 !! Print the network to file
254 class(network_type), intent(in) :: this
255 !! Instance of the network
256 character(*), intent(in) :: file
257 !! File name
258 end subroutine print
259
260 !! Interface for printing a summary of the network
261 module subroutine print_summary(this)
262 !! Print a summary of the network architecture
263 class(network_type), intent(in) :: this
264 !! Instance of the network
265 end subroutine print_summary
266
267 !! Interface for reading the network from a file
268 module subroutine read(this, file)
269 !! Read the network from a file
270 class(network_type), intent(inout) :: this
271 !! Instance of the network
272 character(*), intent(in) :: file
273 !! File name
274 end subroutine read
275
276 !! Interface for reading network settings from a file
277 module subroutine read_network_settings(this, unit)
278 !! Read network settings from a file
279 class(network_type), intent(inout) :: this
280 !! Instance of the network
281 integer, intent(in) :: unit
282 !! Unit number for input
283 end subroutine read_network_settings
284
285 !! Interface for reading optimiser settings from a file
286 module subroutine read_optimiser_settings(this, unit)
287 !! Read optimiser settings from a file
288 class(network_type), intent(inout) :: this
289 !! Instance of the network
290 integer, intent(in) :: unit
291 !! Unit number for input
292 end subroutine read_optimiser_settings
293
294 !! Interface for building network from ONNX nodes and initialisers
295 module subroutine build_from_onnx( &
296 this, nodes, initialisers, inputs, value_info, verbose &
297 )
298 !! Build network from ONNX nodes and initialisers
299 class(network_type), intent(inout) :: this
300 !! Instance of the network
301 type(onnx_node_type), dimension(:), intent(in) :: nodes
302 !! Array of ONNX nodes
303 type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
304 !! Array of ONNX initialisers
305 type(onnx_tensor_type), dimension(:), intent(in) :: inputs
306 !! Array of ONNX input tensors
307 type(onnx_tensor_type), dimension(:), intent(in) :: value_info
308 !! Array of ONNX value info tensors
309 integer, optional, intent(in) :: verbose
310 !! Verbosity level
311 end subroutine build_from_onnx
312
313 !! Interface for adding a layer to the network
314 module subroutine add(this, layer, input_list, output_list, operator)
315 !! Add a layer to the network
316 class(network_type), intent(inout) :: this
317 !! Instance of the network
318 class(base_layer_type), intent(in) :: layer
319 !! Layer to add
320 integer, dimension(:), intent(in), optional :: input_list, output_list
321 !! Input and output list
322 class(*), optional, intent(in) :: operator
323 !! Operator
324 end subroutine add
325
326 !! Interface for resetting the network
327 module subroutine reset(this)
328 !! Reset the network
329 class(network_type), intent(inout) :: this
330 !! Instance of the network
331 end subroutine reset
332
333 !! Interface for compiling the network
334 module subroutine compile( &
335 this, optimiser, loss_method, accuracy_method, &
336 metrics, batch_size, verbose &
337 )
338 !! Compile the network
339 class(network_type), intent(inout) :: this
340 !! Instance of the network
341 class(base_optimiser_type), optional, intent(in) :: optimiser
342 !! Optimiser
343 class(*), optional, intent(in) :: loss_method
344 !! Loss method
345 character(*), optional, intent(in) :: accuracy_method
346 !! Accuracy method
347 class(*), dimension(..), optional, intent(in) :: metrics
348 !! Metrics
349 integer, optional, intent(in) :: batch_size
350 !! Batch size
351 integer, optional, intent(in) :: verbose
352 !! Verbosity level
353 end subroutine compile
354
355 !! Interface for setting batch size
356 module subroutine set_batch_size(this, batch_size)
357 !! Set batch size
358 class(network_type), intent(inout) :: this
359 !! Instance of the network
360 integer, intent(in) :: batch_size
361 !! Batch size
362 end subroutine set_batch_size
363
364 !! Interface for setting network metrics
365 module subroutine set_metrics(this, metrics)
366 !! Set network metrics
367 class(network_type), intent(inout) :: this
368 !! Instance of the network
369 class(*), dimension(..), intent(in) :: metrics
370 !! Metrics
371 end subroutine set_metrics
372
373 !! Interface for setting network loss method
374 module subroutine set_loss(this, loss_method, verbose)
375 !! Set network loss method
376 class(network_type), intent(inout) :: this
377 !! Instance of the network
378 class(*), intent(in) :: loss_method
379 !! Loss method
380 integer, optional, intent(in) :: verbose
381 !! Verbosity level
382 end subroutine set_loss
383
384 !! Interface for setting network accuracy method
385 module subroutine set_accuracy(this, accuracy_method, verbose)
386 !! Set network accuracy method
387 class(network_type), intent(inout) :: this
388 !! Instance of the network
389 character(*), intent(in) :: accuracy_method
390 !! Accuracy method
391 integer, optional, intent(in) :: verbose
392 !! Verbosity level
393 end subroutine set_accuracy
394
395 !! Interface for resetting state of recurrent layers
396 module subroutine reset_state(this)
397 !! Reset hidden state of recurrent layers
398 class(network_type), intent(inout) :: this
399 !! Instance of the network
400 end subroutine reset_state
401
402 module subroutine set_training_mode(this, mode_store, layer_indices)
403 !! Put the network in training mode.
404 !! Layers such as dropout and batch normalisation use their training
405 !! behaviour after this call.
406 class(network_type), intent(inout) :: this
407 !! Instance of the network
408 logical, dimension(:), allocatable, intent(out), optional :: mode_store
409 !! Optional array to store the training mode of each layer
410 integer, dimension(:), intent(in), optional :: layer_indices
411 !! Optional array of layer indices to set to training mode.
412 end subroutine set_training_mode
413
414 module subroutine set_inference_mode(this, mode_store, layer_indices)
415 !! Put the network in inference mode.
416 !! Layers such as dropout and batch normalisation use their inference
417 !! behaviour after this call.
418 class(network_type), intent(inout) :: this
419 !! Instance of the network
420 logical, dimension(:), allocatable, intent(out), optional :: mode_store
421 !! Optional array to store the training mode of each layer
422 integer, dimension(:), intent(in), optional :: layer_indices
423 !! Optional array of layer indices to set to inference mode.
424 end subroutine set_inference_mode
425
426 module subroutine restore_mode(this, mode_store)
427 !! Restore the training/inference mode of layers to the values stored in
428 !! mode_store. This is used after temporarily switching
429 !! modes for prediction or evaluation on a training batch.
430 class(network_type), intent(inout) :: this
431 !! Instance of the network
432 logical, dimension(:), intent(in) :: mode_store
433 end subroutine restore_mode
434
435 !! Interface for saving input to network
436 module function save_input_to_network( this, input ) result(num_samples)
437 !! Convert and save polymorphic input to array or graph
438 class(network_type), intent(inout) :: this
439 !! Instance of network
440 class(*), dimension(..), intent(in) :: input
441 !! Input
442 integer :: num_samples
443 !! Number of samples
444 end function save_input_to_network
445
446 !! Interface for saving output to network
447 module subroutine save_output_to_network( this, output )
448 !! Convert and save polymorphic output to array or graph
449 class(network_type), intent(inout) :: this
450 !! Instance of network
451 class(*), dimension(:,:), intent(in) :: output
452 !! Output
453 end subroutine save_output_to_network
454
455 module function layer_from_id(this, id) result(layer)
456 !! Get the layer of the network from its ID
457 class(network_type), intent(in), target :: this
458 !! Instance of the network
459 integer, intent(in) :: id
460 !! Layer ID
461 class(base_layer_type), pointer :: layer
462 !! Layer pointer
463 end function layer_from_id
464
465
466 !! Interface for training the network
467 module subroutine train( &
468 this, input, output, num_epochs, batch_size, &
469 plateau_threshold, shuffle_batches, batch_print_step, verbose, &
470 print_precision, scientific_print, early_stopping, &
471 val_input, val_output &
472 )
473 !! Train the network
474 class(network_type), intent(inout) :: this
475 !! Instance of the network
476 class(*), dimension(..), intent(in) :: input
477 !! Input data
478 class(*), dimension(:,:), intent(in) :: output
479 !! Expected output data (data labels)
480 integer, intent(in) :: num_epochs
481 !! Number of epochs to train for
482 integer, optional, intent(in) :: batch_size
483 !! Batch size (DEPRECATED)
484 real(real32), optional, intent(in) :: plateau_threshold
485 !! Threshold for checking learning plateau
486 logical, optional, intent(in) :: shuffle_batches
487 !! Shuffle batch order
488 integer, optional, intent(in) :: batch_print_step
489 !! Print step for batch
490 integer, optional, intent(in) :: verbose
491 !! Verbosity level
492 integer, optional, intent(in) :: print_precision
493 !! Number of decimal places to print for training metrics
494 logical, optional, intent(in) :: scientific_print
495 !! Whether to print training metrics in scientific notation
496 logical, optional, intent(in) :: early_stopping
497 !! Whether to stop training early if learning plateau is detected
498 class(*), dimension(..), optional, intent(in) :: val_input
499 !! Validation input data
500 class(*), dimension(:,:), optional, intent(in) :: val_output
501 !! Validation expected output data
502 end subroutine train
503
504 !! Interface for testing the network
505 module subroutine test(this, input, output, verbose)
506 !! Test the network
507 class(network_type), intent(inout) :: this
508 !! Instance of the network
509 class(*), dimension(..), intent(in) :: input
510 !! Input data
511 class(*), dimension(:,:), intent(in) :: output
512 !! Expected output data (data labels)
513 integer, optional, intent(in) :: verbose
514 !! Verbosity level
515 end subroutine test
516
517 !! Interface for returning predicted results from supplied inputs
518 !! using the trained network
519 module function predict_real(this, input, verbose) result(output)
520 !! Get predicted results from supplied inputs using the trained network
521 class(network_type), intent(inout) :: this
522 !! Instance of the network
523 real(real32), dimension(..), intent(in) :: input
524 !! Input data
525 integer, optional, intent(in) :: verbose
526 !! Verbosity level
527 real(real32), dimension(:,:), allocatable :: output
528 !! Predicted output data
529 end function predict_real
530
531 module function predict_array_from_real( &
532 this, input, output_as_array, verbose &
533 ) result(output)
534 !! Get predicted results as array from supplied inputs using the trained network
535 class(network_type), intent(inout) :: this
536 !! Instance of the network
537 class(*), dimension(..), intent(in) :: input
538 !! Input data
539 logical, intent(in) :: output_as_array
540 !! Whether to output as array
541 integer, optional, intent(in) :: verbose
542 !! Verbosity level
543 type(array_type), dimension(:,:), allocatable :: output
544 !! Predicted output data as array
545 end function predict_array_from_real
546
547 !! Interface for returning predicted results from supplied inputs
548 !! using the trained network (graph input)
549 module function predict_graph1d(this, input, verbose) result(output)
550 !! Get predicted results from supplied inputs using the trained network
551 class(network_type), intent(inout) :: this
552 !! Instance of the network
553 type(graph_type), dimension(:), intent(in) :: input
554 !! Input data
555 integer, optional, intent(in) :: verbose
556 !! Verbosity level
557 type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: &
558 output
559 !! Predicted output data
560 end function predict_graph1d
561 module function predict_graph2d(this, input, verbose) result(output)
562 !! Get predicted results from supplied inputs using the trained network
563 class(network_type), intent(inout) :: this
564 !! Instance of the network
565 type(graph_type), dimension(:,:), intent(in) :: input
566 !! Input data
567 integer, optional, intent(in) :: verbose
568 !! Verbosity level
569 type(graph_type), dimension(size(this%leaf_vertices),size(input, 2)) :: &
570 output
571 !! Predicted output data
572 end function predict_graph2d
573
574 module function predict_array( this, input, verbose ) &
575 result(output)
576 !! Predict the output for a generic input
577 class(network_type), intent(inout) :: this
578 !! Instance of network
579 class(array_type), dimension(..), intent(in) :: input
580 !! Input graph
581 integer, intent(in), optional :: verbose
582 !! Verbosity level
583 type(array_type), dimension(:,:), allocatable :: output
584 end function predict_array
585
586 module function predict_generic( this, input, verbose, output_as_graph ) &
587 result(output)
588 !! Predict the output for a generic input
589 class(network_type), intent(inout) :: this
590 !! Instance of network
591 class(*), dimension(:,:), intent(in) :: input
592 !! Input graph
593 integer, intent(in), optional :: verbose
594 !! Verbosity level
595 logical, intent(in), optional :: output_as_graph
596 !! Boolean whether to output as graph
597 class(*), dimension(:,:), allocatable :: output
598 end function predict_generic
599
600 !! Interface for updating the learnable parameters of the network
601 !! based on gradients
602 module subroutine update(this)
603 !! Update the learnable parameters of the network based on gradients
604 class(network_type), intent(inout) :: this
605 !! Instance of the network
606 end subroutine update
607
608 !! Interface for generating vertex order
609 module subroutine build_vertex_order(this)
610 !! Generate vertex order
611 class(network_type), intent(inout) :: this
612 !! Instance of the network
613 end subroutine build_vertex_order
614
615 !! Interface for depth first search
616 recursive module subroutine dfs( &
617 this, vertex_index, visited, order, order_index &
618 )
619 !! Depth first search
620 class(network_type), intent(in) :: this
621 !! Instance of the network
622 integer, intent(in) :: vertex_index
623 !! Vertex index
624 logical, dimension(this%auto_graph%num_vertices), intent(inout) :: &
625 visited
626 !! Visited vertices
627 integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order
628 !! Order of vertices
629 integer, intent(inout) :: order_index
630 !! Index of order
631 end subroutine dfs
632
633 !! Interface for calculating root vertices
634 module subroutine build_root_vertices(this)
635 !! Calculate root vertices
636 class(network_type), intent(inout) :: this
637 !! Instance of the network
638 end subroutine build_root_vertices
639
640 !! Interface for calculating output vertices
641 module subroutine build_leaf_vertices(this)
642 !! Calculate output vertices
643 class(network_type), intent(inout) :: this
644 !! Instance of the network
645 end subroutine build_leaf_vertices
646
647 !! Interface for reducing two networks down to one
648 !! (i.e. add two networks - parallel)
649 module subroutine network_reduction(this, source)
650 !! Reduce two networks down to one (i.e. add two networks - parallel)
651 class(network_type), intent(inout) :: this
652 !! Instance of the network
653 type(network_type), intent(in) :: source
654 !! Source network
655 end subroutine network_reduction
656
657 !! Interface for copying a network
658 module subroutine network_copy(this, source)
659 !! Copy a network
660 class(network_type), intent(inout) :: this
661 !! Instance of the network
662 type(network_type), intent(in), target :: source
663 !! Source network
664 end subroutine network_copy
665
666 !! Interface for getting number of learnable parameters in the network
667 pure module function get_num_params(this) result(num_params)
668 !! Get number of learnable parameters in the network
669 class(network_type), intent(in) :: this
670 !! Instance of the network
671 integer :: num_params
672 !! Number of parameters
673 end function get_num_params
674
675 !! Interface for getting learnable parameters
676 pure module function get_params(this) result(params)
677 !! Get learnable parameters
678 class(network_type), intent(in) :: this
679 !! Instance of the network
680 real(real32), dimension(this%num_params) :: params
681 !! Learnable parameters
682 end function get_params
683
684 !! Interface for setting learnable parameters
685 module subroutine set_params(this, params)
686 !! Set learnable parameters
687 class(network_type), intent(inout) :: this
688 !! Instance of the network
689 real(real32), dimension(this%num_params), intent(in) :: params
690 !! Learnable parameters
691 end subroutine set_params
692
693 !! Interface for getting gradients of learnable parameters
694 pure module function get_gradients(this) result(gradients)
695 !! Get gradients of learnable parameters
696 class(network_type), intent(in) :: this
697 !! Instance of the network
698 real(real32), dimension(this%num_params) :: gradients
699 !! Gradients
700 end function get_gradients
701
702 !! Interface for setting learnable parameter gradients
703 module subroutine set_gradients(this, gradients)
704 !! Set learnable parameter gradients
705 class(network_type), intent(inout) :: this
706 !! Instance of the network
707 real(real32), dimension(..), intent(in) :: gradients
708 !! Gradients
709 end subroutine set_gradients
710
711 !! Interface for resetting learnable parameter gradients
712 module subroutine reset_gradients(this)
713 !! Reset learnable parameter gradients
714 class(network_type), intent(inout) :: this
715 !! Instance of the network
716 end subroutine reset_gradients
717
718 module function get_output(this) result(output)
719 class(network_type), intent(in) :: this
720 !! Instance of the network
721 type(array_type), dimension(:,:), allocatable :: output
722 !! Output
723 end function get_output
724
725 module function get_output_shape(this) result(output_shape)
726 class(network_type), intent(in) :: this
727 !! Instance of the network
728 integer, dimension(2) :: output_shape
729 !! Output shape
730 end function get_output_shape
731
732 module subroutine extract_output_real(this, output)
733 class(network_type), intent(in) :: this
734 !! Instance of network
735 real(real32), dimension(..), allocatable, intent(out) :: output
736 !! Output
737 end subroutine extract_output_real
738
739 module function accuracy_eval(this, output, start_index, end_index) &
740 result(accuracy)
741 !! Get the accuracy for the output
742 class(network_type), intent(in) :: this
743 !! Instance of network
744 class(*), dimension(:,:), intent(in) :: output
745 !! Output
746 integer, intent(in) :: start_index, end_index
747 !! Start and end batch indices
748 real(real32) :: accuracy
749 !! Accuracy value
750 end function accuracy_eval
751
752 module function loss_eval(this, start_index, end_index) result(loss)
753 !! Get the loss for the output
754 ! Arguments
755 class(network_type), intent(inout), target :: this
756 !! Instance of network
757 integer, intent(in) :: start_index, end_index
758 !! Start and end batch indices
759
760 type(array_type), pointer :: loss
761 end function loss_eval
762
763 !! Interface for forward pass
764 module subroutine forward_generic2d(this, input)
765 !! Forward pass for generic 2D input
766 class(network_type), intent(inout), target :: this
767 !! Instance of the network
768 class(*), dimension(:,:), intent(in) :: input
769 !! Input data
770 end subroutine forward_generic2d
771
772 module function forward_eval(this, input) result(output)
773 !! Forward pass evaluation
774 class(network_type), intent(inout), target :: this
775 !! Instance of the network
776 class(*), dimension(:,:), intent(in) :: input
777 !! Input data
778 type(array_type), pointer :: output(:,:)
779 !! Output data
780 end function forward_eval
781
782 module function forward_eval_multi(this, input) result(output)
783 !! Forward pass evaluation for multiple outputs
784 class(network_type), intent(inout), target :: this
785 !! Instance of the network
786 class(*), dimension(:,:), intent(in) :: input
787 !! Input data
788 type(array_ptr_type), pointer :: output(:)
789 !! Output data
790 end function forward_eval_multi
791
792 module subroutine nullify_graph(this)
793 !! Nullify graph data in the network to free memory
794 class(network_type), intent(inout) :: this
795 !! Instance of the network
796 end subroutine nullify_graph
797
798 module subroutine post_epoch_hook(this, epoch, loss, accuracy)
799 !! Hook called after each training epoch.
800 !! The default implementation is a no-op; override in a derived type to
801 !! add custom per-epoch behaviour (e.g. W&B metric logging).
802 class(network_type), intent(inout) :: this
803 !! Instance of the network
804 integer, intent(in) :: epoch
805 !! Current epoch number (1-based)
806 real(real32), intent(in) :: loss
807 !! Current loss value
808 real(real32), intent(in) :: accuracy
809 !! Current accuracy value
810 end subroutine post_epoch_hook
811
812 module function inverse_design_real( &
813 this, target, x_init, optimiser, steps &
814 ) result(x_opt)
815 !! Optimise input to match a target output (real inputs)
816 class(network_type), intent(inout), target :: this
817 !! Instance of the network
818 real(real32), dimension(:,:), intent(in) :: target
819 !! Target output values
820 real(real32), dimension(:,:), intent(in) :: x_init
821 !! Initial input values
822 class(base_optimiser_type), optional, intent(in) :: optimiser
823 !! Optimiser for input updates (defaults to network optimiser)
824 integer, intent(in) :: steps
825 !! Number of optimisation iterations
826 real(real32), dimension(size(x_init,1), size(x_init,2)) :: x_opt
827 !! Optimised input
828 end function inverse_design_real
829
830 module function inverse_design_array_0d( &
831 this, target, x_init, optimiser, steps &
832 ) result(x_opt)
833 !! Optimise input to match a target output (array_type inputs)
834 class(network_type), intent(inout), target :: this
835 !! Instance of the network
836 type(array_type), intent(in) :: target
837 !! Target output values
838 type(array_type), intent(in) :: x_init
839 !! Initial input values
840 class(base_optimiser_type), optional, intent(in) :: optimiser
841 !! Optimiser for input updates (defaults to network optimiser)
842 integer, intent(in) :: steps
843 !! Number of optimisation iterations
844 type(array_type) :: x_opt
845 !! Optimised input
846 end function inverse_design_array_0d
847
848 module function inverse_design_array_2d( &
849 this, target, x_init, optimiser, steps &
850 ) result(x_opt)
851 !! Optimise input to match a target output (array_type inputs)
852 class(network_type), intent(inout), target :: this
853 !! Instance of the network
854 type(array_type), dimension(:,:), intent(in) :: target
855 !! Target output values
856 type(array_type), dimension(:,:), intent(in) :: x_init
857 !! Initial input values
858 class(base_optimiser_type), optional, intent(in) :: optimiser
859 !! Optimiser for input updates (defaults to network optimiser)
860 integer, intent(in) :: steps
861 !! Number of optimisation iterations
862 type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt
863 !! Optimised input
864 end function inverse_design_array_2d
865 end interface
866
867 interface get_sample
868 #ifdef __flang__
869 module function get_sample_flang( &
870 input, start_index, end_index, batch_size &
871 ) result(sample)
872 !! Get a sample from a rank
873 implicit none
874 ! Arguments
875 integer, intent(in) :: start_index, end_index
876 !! Start and end indices
877 integer, intent(in) :: batch_size
878 !! Batch size
879 real(real32), dimension(..), intent(in) :: input
880 !! Input array
881 ! Local variables
882 real(real32), allocatable :: sample(:,:)
883 !! Sample array
884 end function get_sample_flang
885 #else
886 module function get_sample_ptr( &
887 input, start_index, end_index, batch_size &
888 ) result(sample_ptr)
889 !! Get a sample from a rank
890 implicit none
891 ! Arguments
892 integer, intent(in) :: start_index, end_index
893 !! Start and end indices
894 integer, intent(in) :: batch_size
895 !! Batch size
896 real(real32), dimension(..), intent(in), target :: input
897 !! Input array
898 ! Local variables
899 real(real32), pointer :: sample_ptr(:,:)
900 !! Pointer to sample
901 end function get_sample_ptr
902 #endif
903 module function get_sample_array( &
904 input, start_index, end_index, batch_size, as_graph&
905 ) result(sample)
906 !! Get sample for mixed input
907 integer, intent(in) :: start_index, end_index
908 !! Start and end indices
909 integer, intent(in) :: batch_size
910 !! Batch size
911 class(array_type), dimension(:,:), intent(in) :: input
912 !! Input array
913 logical, intent(in) :: as_graph
914 !! Boolean whether to treat the input as a graph
915 type(array_type), dimension(:,:), allocatable :: sample
916 !! Sample array
917 end function get_sample_array
918 module function get_sample_graph1d( &
919 input, start_index, end_index, batch_size &
920 ) result(sample)
921 !! Get sample for graph input
922 integer, intent(in) :: start_index, end_index
923 !! Start and end indices
924 integer, intent(in) :: batch_size
925 !! Batch size
926 class(graph_type), dimension(:), intent(in) :: input
927 !! Input array
928 type(graph_type), dimension(1, batch_size) :: sample
929 !! Sample array
930 end function get_sample_graph1d
931 module function get_sample_graph2d( &
932 input, start_index, end_index, batch_size &
933 ) result(sample)
934 !! Get sample for graph input
935 integer, intent(in) :: start_index, end_index
936 !! Start and end indices
937 integer, intent(in) :: batch_size
938 !! Batch size
939 class(graph_type), dimension(:,:), intent(in) :: input
940 !! Input array
941 type(graph_type), dimension(size(input,1), batch_size) :: sample
942 !! Sample array
943 end function get_sample_graph2d
944 end interface get_sample
945
946 end module athena__network
947