| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__network | ||
| 2 | !! Module containing the network class used to define a neural network | ||
| 3 | !! | ||
| 4 | !! This module contains the types and interfaces for the network class used | ||
| 5 | !! to define a neural network. | ||
| 6 | !! The network class is used to define a neural network with overloaded | ||
| 7 | !! procedures for training, testing, predicting, and updating the network. | ||
| 8 | !! The network class is also used to define the network structure and | ||
| 9 | !! compile the network with an optimiser, loss function, and accuracy | ||
| 10 | !! function. | ||
| 11 | use coreutils, only: real32 | ||
| 12 | use graphstruc, only: graph_type | ||
| 13 | use athena__metrics, only: metric_dict_type | ||
| 14 | use athena__optimiser, only: base_optimiser_type | ||
| 15 | use athena__loss, only: base_loss_type | ||
| 16 | use athena__accuracy, only: comp_acc_func => compute_accuracy_function | ||
| 17 | use athena__base_layer, only: base_layer_type | ||
| 18 | use diffstruc, only: array_type | ||
| 19 | use athena__misc_types, only: & | ||
| 20 | onnx_node_type, onnx_initialiser_type, onnx_tensor_type | ||
| 21 | use athena__container_layer, only: container_layer_type | ||
| 22 | use athena__diffstruc_extd, only: array_ptr_type | ||
| 23 | implicit none | ||
| 24 | |||
| 25 | |||
| 26 | private | ||
| 27 | |||
| 28 | public :: network_type | ||
| 29 | |||
| 30 | |||
| 31 | type :: network_type | ||
| 32 | !! Type for defining a neural network with overloaded procedures | ||
| 33 | character(len=:), allocatable :: name | ||
| 34 | !! Name of the network | ||
| 35 | real(real32) :: accuracy_val, loss_val | ||
| 36 | !! Accuracy and loss values of the network (no default initialisation) | ||
| 37 | integer :: batch_size = 0 | ||
| 38 | !! Batch size | ||
| 39 | integer :: epoch = 0 | ||
| 40 | !! Epoch number | ||
| 41 | integer :: num_layers = 0 | ||
| 42 | !! Number of layers | ||
| 43 | integer :: num_outputs = 0 | ||
| 44 | !! Number of outputs | ||
| 45 | integer :: num_params = 0 | ||
| 46 | !! Number of parameters | ||
| 47 | logical :: use_graph_input = .false. | ||
| 48 | !! Boolean flag for graph input | ||
| 49 | logical :: use_graph_output = .false. | ||
| 50 | !! Boolean flag for graph output | ||
| 51 | class(base_optimiser_type), allocatable :: optimiser | ||
| 52 | !! Optimiser for the network | ||
| 53 | class(base_loss_type), allocatable :: loss | ||
| 54 | !! Loss method for the network | ||
| 55 | type(metric_dict_type), dimension(2) :: metrics | ||
| 56 | !! Metrics for the network (two entries; presumably loss and accuracy - confirm) | ||
| 57 | type(container_layer_type), allocatable, dimension(:) :: model | ||
| 58 | !! Model layers | ||
| 59 | character(len=:), allocatable :: loss_method, accuracy_method | ||
| 60 | !! Loss and accuracy method names | ||
| 61 | procedure(comp_acc_func), nopass, pointer :: get_accuracy => null() | ||
| 62 | !! Pointer to accuracy function (null until set) | ||
| 63 | integer, dimension(:), allocatable :: vertex_order | ||
| 64 | !! Order of vertices | ||
| 65 | integer, dimension(:), allocatable :: root_vertices, leaf_vertices | ||
| 66 | !! Root and leaf (output) vertices | ||
| 67 | type(graph_type) :: auto_graph | ||
| 68 | !! Graph structure for the network | ||
| 69 | |||
| 70 | ! Pre-computed forward pass navigation (populated during compile) | ||
| 71 | integer, dimension(:), allocatable :: fwd_layer_id | ||
| 72 | !! Layer ID for each vertex in forward order | ||
| 73 | integer, dimension(:), allocatable :: fwd_num_inputs | ||
| 74 | !! Number of input layers for each vertex in forward order | ||
| 75 | integer, dimension(:), allocatable :: fwd_parent_id | ||
| 76 | !! Parent layer ID for single-input vertices | ||
| 77 | integer, dimension(:), allocatable :: fwd_layer_type | ||
| 78 | !! Layer type: 0=input, 1=merge, 2=default | ||
| 79 | |||
| 80 | ! Pre-computed parameter segment layout (populated during compile) | ||
| 81 | integer :: param_num_segments = 0 | ||
| 82 | !! Number of parameter segments | ||
| 83 | integer, dimension(:), allocatable :: param_seg_layer | ||
| 84 | !! Layer index for each parameter segment | ||
| 85 | integer, dimension(:), allocatable :: param_seg_pidx | ||
| 86 | !! Param index within that layer for each segment | ||
| 87 | integer, dimension(:), allocatable :: param_seg_start | ||
| 88 | !! Start offset in flat parameter array | ||
| 89 | integer, dimension(:), allocatable :: param_seg_end | ||
| 90 | !! End offset in flat parameter array | ||
| 91 | |||
| 92 | type(array_type), dimension(:,:), allocatable :: input_array | ||
| 93 | !! Input array for the network | ||
| 94 | type(graph_type), dimension(:,:), allocatable :: input_graph | ||
| 95 | !! Input graph for the network | ||
| 96 | type(array_type), dimension(:,:), allocatable :: expected_array | ||
| 97 | !! Expected output array for the network | ||
| 98 | contains | ||
| 99 | procedure, pass(this) :: print | ||
| 100 | !! Print the network to file | ||
| 101 | procedure, pass(this) :: print_summary | ||
| 102 | !! Print a summary of the network architecture | ||
| 103 | procedure, pass(this) :: read | ||
| 104 | !! Read the network from a file | ||
| 105 | procedure, pass(this), private :: read_network_settings | ||
| 106 | !! Read network settings from a file unit | ||
| 107 | procedure, pass(this), private :: read_optimiser_settings | ||
| 108 | !! Read optimiser settings from a file unit | ||
| 109 | procedure, pass(this) :: build_from_onnx | ||
| 110 | !! Build network from ONNX nodes and initialisers | ||
| 111 | procedure, pass(this) :: add | ||
| 112 | !! Add a layer to the network | ||
| 113 | procedure, pass(this) :: reset | ||
| 114 | !! Reset the network | ||
| 115 | procedure, pass(this) :: compile | ||
| 116 | !! Compile the network | ||
| 117 | procedure, pass(this) :: set_batch_size | ||
| 118 | !! Set batch size | ||
| 119 | procedure, pass(this) :: set_metrics | ||
| 120 | !! Set network metrics | ||
| 121 | procedure, pass(this) :: set_loss | ||
| 122 | !! Set network loss method | ||
| 123 | procedure, pass(this) :: set_accuracy | ||
| 124 | !! Set network accuracy method | ||
| 125 | procedure, pass(this) :: reset_state | ||
| 126 | !! Reset hidden state of recurrent layers | ||
| 127 | procedure, pass(this) :: set_training_mode | ||
| 128 | !! Set training mode for layers with training/inference-specific behaviour | ||
| 129 | procedure, pass(this) :: set_inference_mode | ||
| 130 | !! Set inference mode for layers with training/inference-specific behaviour | ||
| 131 | procedure, pass(this), private :: restore_mode | ||
| 132 | !! Restore the training/inference mode of layers to the values stored in mode_store. | ||
| 133 | |||
| 134 | procedure, pass(this) :: save_input => save_input_to_network | ||
| 135 | !! Convert and save polymorphic input to array or graph | ||
| 136 | procedure, pass(this) :: save_output => save_output_to_network | ||
| 137 | !! Convert and save polymorphic output to array or graph | ||
| 138 | |||
| 139 | procedure, pass(this) :: layer_from_id | ||
| 140 | !! Get a pointer to the layer of the network from its ID | ||
| 141 | |||
| 142 | procedure, pass(this) :: train | ||
| 143 | !! Train the network | ||
| 144 | procedure, pass(this) :: test | ||
| 145 | !! Test the network | ||
| 146 | |||
| 147 | procedure, pass(this) :: predict_real | ||
| 148 | !! Return predicted results from supplied real inputs using the trained network | ||
| 149 | procedure, pass(this) :: predict_array_from_real | ||
| 150 | !! Return predicted results as array from supplied inputs using the trained network | ||
| 151 | procedure, pass(this) :: predict_graph1d, predict_graph2d | ||
| 152 | !! Return predicted results from supplied inputs using the trained network (graph input) | ||
| 153 | procedure, pass(this) :: predict_array | ||
| 154 | !! Predict array type output for an array_type input | ||
| 155 | procedure, pass(this) :: predict_generic | ||
| 156 | !! Predict generic type output for a generic input (not part of the generic predict binding) | ||
| 157 | generic :: predict => & | ||
| 158 | predict_real, predict_graph1d, predict_graph2d, & | ||
| 159 | predict_array, predict_array_from_real | ||
| 160 | !! Predict function for different input types | ||
| 161 | |||
| 162 | |||
| 163 | procedure, pass(this), private :: dfs | ||
| 164 | !! Depth first search | ||
| 165 | procedure, pass(this), private :: build_vertex_order | ||
| 166 | !! Generate vertex order | ||
| 167 | procedure, pass(this), private :: build_root_vertices | ||
| 168 | !! Calculate root vertices | ||
| 169 | procedure, pass(this), private :: build_leaf_vertices | ||
| 170 | !! Calculate leaf (output) vertices | ||
| 171 | |||
| 172 | procedure, pass(this) :: reduce => network_reduction | ||
| 173 | !! Reduce two networks down to one (i.e. add two networks - parallel) | ||
| 174 | procedure, pass(this) :: copy => network_copy | ||
| 175 | !! Copy a network | ||
| 176 | |||
| 177 | procedure, pass(this) :: get_num_params | ||
| 178 | !! Get number of learnable parameters in the network | ||
| 179 | procedure, pass(this) :: get_params | ||
| 180 | !! Get learnable parameters | ||
| 181 | procedure, pass(this) :: set_params | ||
| 182 | !! Set learnable parameters | ||
| 183 | procedure, pass(this) :: get_gradients | ||
| 184 | !! Get gradients of learnable parameters | ||
| 185 | procedure, pass(this) :: set_gradients | ||
| 186 | !! Set learnable parameter gradients | ||
| 187 | procedure, pass(this) :: reset_gradients | ||
| 188 | !! Reset learnable parameter gradients | ||
| 189 | procedure, pass(this) :: get_output | ||
| 190 | !! Get the output of the network | ||
| 191 | procedure, pass(this) :: get_output_shape | ||
| 192 | !! Get the output shape of the network | ||
| 193 | procedure, pass(this) :: extract_output => extract_output_real | ||
| 194 | !! Extract network output as real array (only works for single output layer models) | ||
| 195 | |||
| 196 | procedure, pass(this) :: forward => forward_generic2d | ||
| 197 | !! Forward pass for generic 2D input | ||
| 198 | procedure, pass(this) :: forward_eval | ||
| 199 | !! Forward pass and return pointer to output (only works for single output layer models) | ||
| 200 | procedure, pass(this) :: accuracy_eval | ||
| 201 | !! Get the accuracy for the output | ||
| 202 | procedure, pass(this) :: loss_eval | ||
| 203 | !! Get the loss for the output | ||
| 204 | procedure, pass(this) :: update | ||
| 205 | !! Update the learnable parameters of the network based on gradients | ||
| 206 | |||
| 207 | procedure, pass(this) :: nullify_graph | ||
| 208 | !! Nullify graph data in the network to free memory | ||
| 209 | |||
| 210 | procedure, pass(this) :: post_epoch_hook | ||
| 211 | !! Called after each training epoch; override in derived types for custom | ||
| 212 | !! per-epoch callbacks (e.g. logging to Weights & Biases). | ||
| 213 | |||
| 214 | procedure, pass(this), private :: inverse_design_real | ||
| 215 | !! Inverse design with real inputs | ||
| 216 | procedure, pass(this), private :: inverse_design_array_0d | ||
| 217 | !! Inverse design with 0d array_type inputs | ||
| 218 | procedure, pass(this), private :: inverse_design_array_2d | ||
| 219 | !! Inverse design with 2d array_type inputs | ||
| 220 | generic :: inverse_design => & | ||
| 221 | inverse_design_real, inverse_design_array_0d, inverse_design_array_2d | ||
| 222 | !! Optimise input to match a target output | ||
| 223 | end type network_type | ||
| 224 | |||
| 225 | interface network_type | ||
| 226 | !! Interface for setting up the network (network initialisation) | ||
| 227 | module function network_setup( & | ||
| 228 | layers, & | ||
| 229 | optimiser, loss_method, accuracy_method, & | ||
| 230 | metrics, batch_size & | ||
| 231 | ) result(network) | ||
| 232 | !! Set up the network; only layers is required, all other arguments are optional | ||
| 233 | type(container_layer_type), dimension(:), intent(in) :: layers | ||
| 234 | !! Layers (required) | ||
| 235 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 236 | !! Optimiser (optional) | ||
| 237 | class(*), optional, intent(in) :: loss_method | ||
| 238 | !! Loss method (optional, unlimited polymorphic) | ||
| 239 | character(*), optional, intent(in) :: accuracy_method | ||
| 240 | !! Accuracy method (optional, character string) | ||
| 241 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 242 | !! Metrics (optional, unlimited polymorphic, assumed rank) | ||
| 243 | integer, optional, intent(in) :: batch_size | ||
| 244 | !! Batch size (optional) | ||
| 245 | type(network_type) :: network | ||
| 246 | !! Resulting instance of the network | ||
| 247 | end function network_setup | ||
| 248 | end interface network_type | ||
| 249 | |||
| 250 | interface | ||
| 251 | !! Interface for printing the network to file | ||
| 252 | module subroutine print(this, file) | ||
| 253 | !! Print the network to file | ||
| 254 | class(network_type), intent(in) :: this | ||
| 255 | !! Instance of the network | ||
| 256 | character(*), intent(in) :: file | ||
| 257 | !! Name of the file to print the network to | ||
| 258 | end subroutine print | ||
| 259 | |||
| 260 | !! Interface for printing a summary of the network | ||
| 261 | module subroutine print_summary(this) | ||
| 262 | !! Print a summary of the network architecture | ||
| 263 | class(network_type), intent(in) :: this | ||
| 264 | !! Instance of the network (intent(in): not modified) | ||
| 265 | end subroutine print_summary | ||
| 266 | |||
| 267 | !! Interface for reading the network from a file | ||
| 268 | module subroutine read(this, file) | ||
| 269 | !! Read the network from a file | ||
| 270 | class(network_type), intent(inout) :: this | ||
| 271 | !! Instance of the network | ||
| 272 | character(*), intent(in) :: file | ||
| 273 | !! Name of the file to read the network from | ||
| 274 | end subroutine read | ||
| 275 | |||
| 276 | !! Interface for reading network settings from a file | ||
| 277 | module subroutine read_network_settings(this, unit) | ||
| 278 | !! Read network settings from a file identified by its unit number | ||
| 279 | class(network_type), intent(inout) :: this | ||
| 280 | !! Instance of the network | ||
| 281 | integer, intent(in) :: unit | ||
| 282 | !! Unit number for input (presumably already opened by the caller - confirm) | ||
| 283 | end subroutine read_network_settings | ||
| 284 | |||
| 285 | !! Interface for reading optimiser settings from a file | ||
| 286 | module subroutine read_optimiser_settings(this, unit) | ||
| 287 | !! Read optimiser settings from a file identified by its unit number | ||
| 288 | class(network_type), intent(inout) :: this | ||
| 289 | !! Instance of the network | ||
| 290 | integer, intent(in) :: unit | ||
| 291 | !! Unit number for input (presumably already opened by the caller - confirm) | ||
| 292 | end subroutine read_optimiser_settings | ||
| 293 | |||
| 294 | !! Interface for building network from ONNX nodes and initialisers | ||
| 295 | module subroutine build_from_onnx( & | ||
| 296 | this, nodes, initialisers, inputs, value_info, verbose & | ||
| 297 | ) | ||
| 298 | !! Build network from ONNX nodes and initialisers | ||
| 299 | class(network_type), intent(inout) :: this | ||
| 300 | !! Instance of the network | ||
| 301 | type(onnx_node_type), dimension(:), intent(in) :: nodes | ||
| 302 | !! Array of ONNX nodes | ||
| 303 | type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers | ||
| 304 | !! Array of ONNX initialisers | ||
| 305 | type(onnx_tensor_type), dimension(:), intent(in) :: inputs | ||
| 306 | !! Array of ONNX input tensors | ||
| 307 | type(onnx_tensor_type), dimension(:), intent(in) :: value_info | ||
| 308 | !! Array of ONNX value info tensors | ||
| 309 | integer, optional, intent(in) :: verbose | ||
| 310 | !! Verbosity level (optional) | ||
| 311 | end subroutine build_from_onnx | ||
| 312 | |||
| 313 | !! Interface for adding a layer to the network | ||
| 314 | module subroutine add(this, layer, input_list, output_list, operator) | ||
| 315 | !! Add a layer to the network | ||
| 316 | class(network_type), intent(inout) :: this | ||
| 317 | !! Instance of the network | ||
| 318 | class(base_layer_type), intent(in) :: layer | ||
| 319 | !! Layer to add | ||
| 320 | integer, dimension(:), intent(in), optional :: input_list, output_list | ||
| 321 | !! Optional input and output lists (presumably IDs of connected layers - confirm) | ||
| 322 | class(*), optional, intent(in) :: operator | ||
| 323 | !! Operator (optional, unlimited polymorphic) | ||
| 324 | end subroutine add | ||
| 325 | |||
| 326 | !! Interface for resetting the network | ||
| 327 | module subroutine reset(this) | ||
| 328 | !! Reset the network | ||
| 329 | class(network_type), intent(inout) :: this | ||
| 330 | !! Instance of the network to reset | ||
| 331 | end subroutine reset | ||
| 332 | |||
| 333 | !! Interface for compiling the network | ||
| 334 | module subroutine compile( & | ||
| 335 | this, optimiser, loss_method, accuracy_method, & | ||
| 336 | metrics, batch_size, verbose & | ||
| 337 | ) | ||
| 338 | !! Compile the network; every argument other than this is optional | ||
| 339 | class(network_type), intent(inout) :: this | ||
| 340 | !! Instance of the network | ||
| 341 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 342 | !! Optimiser (optional) | ||
| 343 | class(*), optional, intent(in) :: loss_method | ||
| 344 | !! Loss method (optional, unlimited polymorphic) | ||
| 345 | character(*), optional, intent(in) :: accuracy_method | ||
| 346 | !! Accuracy method (optional, character string) | ||
| 347 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 348 | !! Metrics (optional, unlimited polymorphic, assumed rank) | ||
| 349 | integer, optional, intent(in) :: batch_size | ||
| 350 | !! Batch size (optional) | ||
| 351 | integer, optional, intent(in) :: verbose | ||
| 352 | !! Verbosity level (optional) | ||
| 353 | end subroutine compile | ||
| 354 | |||
| 355 | !! Interface for setting batch size | ||
| 356 | module subroutine set_batch_size(this, batch_size) | ||
| 357 | !! Set batch size | ||
| 358 | class(network_type), intent(inout) :: this | ||
| 359 | !! Instance of the network | ||
| 360 | integer, intent(in) :: batch_size | ||
| 361 | !! New batch size | ||
| 362 | end subroutine set_batch_size | ||
| 363 | |||
| 364 | !! Interface for setting network metrics | ||
| 365 | module subroutine set_metrics(this, metrics) | ||
| 366 | !! Set network metrics | ||
| 367 | class(network_type), intent(inout) :: this | ||
| 368 | !! Instance of the network | ||
| 369 | class(*), dimension(..), intent(in) :: metrics | ||
| 370 | !! Metrics (unlimited polymorphic, assumed rank) | ||
| 371 | end subroutine set_metrics | ||
| 372 | |||
| 373 | !! Interface for setting network loss method | ||
| 374 | module subroutine set_loss(this, loss_method, verbose) | ||
| 375 | !! Set network loss method | ||
| 376 | class(network_type), intent(inout) :: this | ||
| 377 | !! Instance of the network | ||
| 378 | class(*), intent(in) :: loss_method | ||
| 379 | !! Loss method (unlimited polymorphic) | ||
| 380 | integer, optional, intent(in) :: verbose | ||
| 381 | !! Verbosity level (optional) | ||
| 382 | end subroutine set_loss | ||
| 383 | |||
| 384 | !! Interface for setting network accuracy method | ||
| 385 | module subroutine set_accuracy(this, accuracy_method, verbose) | ||
| 386 | !! Set network accuracy method | ||
| 387 | class(network_type), intent(inout) :: this | ||
| 388 | !! Instance of the network | ||
| 389 | character(*), intent(in) :: accuracy_method | ||
| 390 | !! Accuracy method (character string) | ||
| 391 | integer, optional, intent(in) :: verbose | ||
| 392 | !! Verbosity level (optional) | ||
| 393 | end subroutine set_accuracy | ||
| 394 | |||
| 395 | !! Interface for resetting state of recurrent layers | ||
| 396 | module subroutine reset_state(this) | ||
| 397 | !! Reset hidden state of recurrent layers | ||
| 398 | class(network_type), intent(inout) :: this | ||
| 399 | !! Instance of the network whose recurrent layers are reset | ||
| 400 | end subroutine reset_state | ||
| 401 | |||
| 402 | module subroutine set_training_mode(this, mode_store, layer_indices) | ||
| 403 | !! Put the network in training mode. | ||
| 404 | !! Layers such as dropout and batch normalisation use their training | ||
| 405 | !! behaviour after this call. | ||
| 406 | class(network_type), intent(inout) :: this | ||
| 407 | !! Instance of the network | ||
| 408 | logical, dimension(:), allocatable, intent(out), optional :: mode_store | ||
| 409 | !! Optional output array to store the training mode of each layer (presumably the modes before this call, for restore_mode - confirm) | ||
| 410 | integer, dimension(:), intent(in), optional :: layer_indices | ||
| 411 | !! Optional array of layer indices to set to training mode (presumably all layers when absent - confirm). | ||
| 412 | end subroutine set_training_mode | ||
| 413 | |||
| 414 | module subroutine set_inference_mode(this, mode_store, layer_indices) | ||
| 415 | !! Put the network in inference mode. | ||
| 416 | !! Layers such as dropout and batch normalisation use their inference | ||
| 417 | !! behaviour after this call. | ||
| 418 | class(network_type), intent(inout) :: this | ||
| 419 | !! Instance of the network | ||
| 420 | logical, dimension(:), allocatable, intent(out), optional :: mode_store | ||
| 421 | !! Optional output array to store the training mode of each layer (presumably the modes before this call, for restore_mode - confirm) | ||
| 422 | integer, dimension(:), intent(in), optional :: layer_indices | ||
| 423 | !! Optional array of layer indices to set to inference mode (presumably all layers when absent - confirm). | ||
| 424 | end subroutine set_inference_mode | ||
| 425 | |||
| 426 | module subroutine restore_mode(this, mode_store) | ||
| 427 | !! Restore the training/inference mode of layers to the values stored in | ||
| 428 | !! mode_store (logical array, input only). This is used after temporarily switching | ||
| 429 | !! modes for prediction or evaluation on a training batch. | ||
| 430 | class(network_type), intent(inout) :: this | ||
| 431 | !! Instance of the network whose layer modes are restored | ||
| 432 | logical, dimension(:), intent(in) :: mode_store | ||
| 433 | end subroutine restore_mode | ||
| 434 | |||
| 435 | !! Interface for saving input to network | ||
| 436 | module function save_input_to_network( this, input ) result(num_samples) | ||
| 437 | !! Convert and save polymorphic input to array or graph | ||
| 438 | class(network_type), intent(inout) :: this | ||
| 439 | !! Instance of network | ||
| 440 | class(*), dimension(..), intent(in) :: input | ||
| 441 | !! Input data (unlimited polymorphic, assumed rank) | ||
| 442 | integer :: num_samples | ||
| 443 | !! Number of samples (function result) | ||
| 444 | end function save_input_to_network | ||
| 445 | |||
| 446 | !! Interface for saving output to network | ||
| 447 | module subroutine save_output_to_network( this, output ) | ||
| 448 | !! Convert and save polymorphic output to array or graph | ||
| 449 | class(network_type), intent(inout) :: this | ||
| 450 | !! Instance of network | ||
| 451 | class(*), dimension(:,:), intent(in) :: output | ||
| 452 | !! Output data (unlimited polymorphic, rank 2) | ||
| 453 | end subroutine save_output_to_network | ||
| 454 | |||
| 455 | module function layer_from_id(this, id) result(layer) | ||
| 456 | !! Get the layer of the network from its ID | ||
| 457 | class(network_type), intent(in), target :: this | ||
| 458 | !! Instance of the network (target; the returned pointer presumably refers into it - confirm) | ||
| 459 | integer, intent(in) :: id | ||
| 460 | !! Layer ID | ||
| 461 | class(base_layer_type), pointer :: layer | ||
| 462 | !! Pointer to the layer (function result) | ||
| 463 | end function layer_from_id | ||
| 464 | |||
| 465 | |||
| 466 | !! Interface for training the network | ||
| 467 | module subroutine train( & | ||
| 468 | this, input, output, num_epochs, batch_size, & | ||
| 469 | plateau_threshold, shuffle_batches, batch_print_step, verbose, & | ||
| 470 | print_precision, scientific_print, early_stopping, & | ||
| 471 | val_input, val_output & | ||
| 472 | ) | ||
| 473 | !! Train the network; input, output and num_epochs are required, the rest are optional | ||
| 474 | class(network_type), intent(inout) :: this | ||
| 475 | !! Instance of the network | ||
| 476 | class(*), dimension(..), intent(in) :: input | ||
| 477 | !! Input data (unlimited polymorphic, assumed rank) | ||
| 478 | class(*), dimension(:,:), intent(in) :: output | ||
| 479 | !! Expected output data (data labels; unlimited polymorphic, rank 2) | ||
| 480 | integer, intent(in) :: num_epochs | ||
| 481 | !! Number of epochs to train for | ||
| 482 | integer, optional, intent(in) :: batch_size | ||
| 483 | !! Batch size (DEPRECATED; the replacement is not visible here - presumably set_batch_size or compile, confirm) | ||
| 484 | real(real32), optional, intent(in) :: plateau_threshold | ||
| 485 | !! Threshold for checking learning plateau | ||
| 486 | logical, optional, intent(in) :: shuffle_batches | ||
| 487 | !! Shuffle batch order | ||
| 488 | integer, optional, intent(in) :: batch_print_step | ||
| 489 | !! Print step for batch | ||
| 490 | integer, optional, intent(in) :: verbose | ||
| 491 | !! Verbosity level | ||
| 492 | integer, optional, intent(in) :: print_precision | ||
| 493 | !! Number of decimal places to print for training metrics | ||
| 494 | logical, optional, intent(in) :: scientific_print | ||
| 495 | !! Whether to print training metrics in scientific notation | ||
| 496 | logical, optional, intent(in) :: early_stopping | ||
| 497 | !! Whether to stop training early if learning plateau is detected | ||
| 498 | class(*), dimension(..), optional, intent(in) :: val_input | ||
| 499 | !! Validation input data (same declared form as input) | ||
| 500 | class(*), dimension(:,:), optional, intent(in) :: val_output | ||
| 501 | !! Validation expected output data (same declared form as output) | ||
| 502 | end subroutine train | ||
| 503 | |||
| 504 | !! Interface for testing the network | ||
| 505 | module subroutine test(this, input, output, verbose) | ||
| 506 | !! Test the network | ||
| 507 | class(network_type), intent(inout) :: this | ||
| 508 | !! Instance of the network | ||
| 509 | class(*), dimension(..), intent(in) :: input | ||
| 510 | !! Input data (unlimited polymorphic, assumed rank) | ||
| 511 | class(*), dimension(:,:), intent(in) :: output | ||
| 512 | !! Expected output data (data labels; unlimited polymorphic, rank 2) | ||
| 513 | integer, optional, intent(in) :: verbose | ||
| 514 | !! Verbosity level (optional) | ||
| 515 | end subroutine test | ||
| 516 | |||
| 517 | !! Interface for returning predicted results from supplied inputs | ||
| 518 | !! using the trained network | ||
| 519 | module function predict_real(this, input, verbose) result(output) | ||
| 520 | !! Get predicted results from supplied real inputs using the trained network | ||
| 521 | class(network_type), intent(inout) :: this | ||
| 522 | !! Instance of the network | ||
| 523 | real(real32), dimension(..), intent(in) :: input | ||
| 524 | !! Input data (real32, assumed rank) | ||
| 525 | integer, optional, intent(in) :: verbose | ||
| 526 | !! Verbosity level (optional) | ||
| 527 | real(real32), dimension(:,:), allocatable :: output | ||
| 528 | !! Predicted output data (allocatable rank-2 real32 array) | ||
| 529 | end function predict_real | ||
| 530 | |||
| 531 | module function predict_array_from_real( & | ||
| 532 | this, input, output_as_array, verbose & | ||
| 533 | ) result(output) | ||
| 534 | !! Get predicted results as array from supplied inputs using the trained network | ||
| 535 | class(network_type), intent(inout) :: this | ||
| 536 | !! Instance of the network | ||
| 537 | class(*), dimension(..), intent(in) :: input | ||
| 538 | !! Input data (declared unlimited polymorphic and assumed rank, despite the procedure name) | ||
| 539 | logical, intent(in) :: output_as_array | ||
| 540 | !! Whether to output as array (required argument) | ||
| 541 | integer, optional, intent(in) :: verbose | ||
| 542 | !! Verbosity level (optional) | ||
| 543 | type(array_type), dimension(:,:), allocatable :: output | ||
| 544 | !! Predicted output data as a rank-2 array of array_type | ||
| 545 | end function predict_array_from_real | ||
| 546 | |||
| 547 | !! Interface for returning predicted results from supplied inputs | ||
| 548 | !! using the trained network (graph input) | ||
| 549 | module function predict_graph1d(this, input, verbose) result(output) | ||
| 550 | !! Get predicted results from supplied rank-1 graph inputs using the trained network | ||
| 551 | class(network_type), intent(inout) :: this | ||
| 552 | !! Instance of the network | ||
| 553 | type(graph_type), dimension(:), intent(in) :: input | ||
| 554 | !! Input data (rank-1 array of graphs) | ||
| 555 | integer, optional, intent(in) :: verbose | ||
| 556 | !! Verbosity level (optional) | ||
| 557 | type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: & | ||
| 558 | output | ||
| 559 | !! Predicted output data, shaped (number of leaf vertices, size of input) | ||
| 560 | end function predict_graph1d | ||
| 561 | module function predict_graph2d(this, input, verbose) result(output) | ||
| 562 | !! Get predicted results from supplied rank-2 graph inputs using the trained network | ||
| 563 | class(network_type), intent(inout) :: this | ||
| 564 | !! Instance of the network | ||
| 565 | type(graph_type), dimension(:,:), intent(in) :: input | ||
| 566 | !! Input data (rank-2 array of graphs) | ||
| 567 | integer, optional, intent(in) :: verbose | ||
| 568 | !! Verbosity level (optional) | ||
| 569 | type(graph_type), dimension(size(this%leaf_vertices),size(input, 2)) :: & | ||
| 570 | output | ||
| 571 | !! Predicted output data, shaped (number of leaf vertices, extent of input along dimension 2) | ||
| 572 | end function predict_graph2d | ||
| 573 | |||
| 574 | module function predict_array( this, input, verbose ) & | ||
| 575 | result(output) | ||
| 576 | !! Predict the output for an array_type input; the result is a rank-2 array of array_type | ||
| 577 | class(network_type), intent(inout) :: this | ||
| 578 | !! Instance of network | ||
| 579 | class(array_type), dimension(..), intent(in) :: input | ||
| 580 | !! Input array (array_type or an extension, assumed rank) | ||
| 581 | integer, intent(in), optional :: verbose | ||
| 582 | !! Verbosity level (optional) | ||
| 583 | type(array_type), dimension(:,:), allocatable :: output | ||
| 584 | end function predict_array | ||
| 585 | |||
| 586 | module function predict_generic( this, input, verbose, output_as_graph ) & | ||
| 587 | result(output) | ||
| 588 | !! Predict the output for a generic input; the result is an unlimited polymorphic rank-2 array | ||
| 589 | class(network_type), intent(inout) :: this | ||
| 590 | !! Instance of network | ||
| 591 | class(*), dimension(:,:), intent(in) :: input | ||
| 592 | !! Input data (unlimited polymorphic, rank 2) | ||
| 593 | integer, intent(in), optional :: verbose | ||
| 594 | !! Verbosity level (optional) | ||
| 595 | logical, intent(in), optional :: output_as_graph | ||
| 596 | !! Boolean whether to output as graph (optional) | ||
| 597 | class(*), dimension(:,:), allocatable :: output | ||
| 598 | end function predict_generic | ||
| 599 | |||
| 600 | !! Interface for updating the learnable parameters of the network | ||
| 601 | !! based on gradients | ||
| 602 | module subroutine update(this) | ||
| 603 | !! Update the learnable parameters of the network based on gradients | ||
| 604 | class(network_type), intent(inout) :: this | ||
| 605 | !! Instance of the network whose learnable parameters are updated | ||
| 606 | end subroutine update | ||
| 607 | |||
| 608 | !! Interface for generating vertex order | ||
| 609 | module subroutine build_vertex_order(this) | ||
| 610 | !! Generate vertex order (presumably stored in this%vertex_order - confirm) | ||
| 611 | class(network_type), intent(inout) :: this | ||
| 612 | !! Instance of the network | ||
| 613 | end subroutine build_vertex_order | ||
| 614 | |||
| 615 | !! Interface for depth first search | ||
| 616 | recursive module subroutine dfs( & | ||
| 617 | this, vertex_index, visited, order, order_index & | ||
| 618 | ) | ||
| 619 | !! Depth first search (recursive) | ||
| 620 | class(network_type), intent(in) :: this | ||
| 621 | !! Instance of the network | ||
| 622 | integer, intent(in) :: vertex_index | ||
| 623 | !! Vertex index | ||
| 624 | logical, dimension(this%auto_graph%num_vertices), intent(inout) :: & | ||
| 625 | visited | ||
| 626 | !! Visited flags, one per vertex of this%auto_graph (updated in place) | ||
| 627 | integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order | ||
| 628 | !! Order of vertices, one entry per vertex of this%auto_graph (updated in place) | ||
| 629 | integer, intent(inout) :: order_index | ||
| 630 | !! Index of order (updated in place) | ||
| 631 | end subroutine dfs | ||
| 632 | |||
| 633 | !! Interface for calculating root vertices | ||
| 634 | module subroutine build_root_vertices(this) | ||
| 635 | !! Calculate root vertices (presumably stored in this%root_vertices - confirm) | ||
| 636 | class(network_type), intent(inout) :: this | ||
| 637 | !! Instance of the network | ||
| 638 | end subroutine build_root_vertices | ||
| 639 | |||
| 640 | !! Interface for calculating output vertices | ||
| 641 | module subroutine build_leaf_vertices(this) | ||
| 642 | !! Calculate leaf (output) vertices (presumably stored in this%leaf_vertices - confirm) | ||
| 643 | class(network_type), intent(inout) :: this | ||
| 644 | !! Instance of the network | ||
| 645 | end subroutine build_leaf_vertices | ||
| 646 | |||
| 647 | !! Interface for reducing two networks down to one | ||
| 648 | !! (i.e. add two networks - parallel) | ||
| 649 | module subroutine network_reduction(this, source) | ||
| 650 | !! Reduce two networks down to one (i.e. add two networks - parallel) | ||
| 651 | class(network_type), intent(inout) :: this | ||
| 652 | !! Instance of the network (presumably receives the reduced result - confirm) | ||
| 653 | type(network_type), intent(in) :: source | ||
| 654 | !! Source network (intent(in): not modified) | ||
| 655 | end subroutine network_reduction | ||
| 656 | |||
| 657 | !! Interface for copying a network | ||
| 658 | module subroutine network_copy(this, source) | ||
| 659 | !! Copy a network | ||
| 660 | class(network_type), intent(inout) :: this | ||
| 661 | !! Instance of the network (presumably the destination of the copy - confirm) | ||
| 662 | type(network_type), intent(in), target :: source | ||
| 663 | !! Source network (intent(in), target) | ||
| 664 | end subroutine network_copy | ||
| 665 | |||
| 666 | !! Interface for getting number of learnable parameters in the network | ||
| 667 | pure module function get_num_params(this) result(num_params) | ||
| 668 | !! Get the number of learnable parameters in the network (pure; interface only) | ||
| 669 | class(network_type), intent(in) :: this | ||
| 670 | !! Network instance (read-only) | ||
| 671 | integer :: num_params | ||
| 672 | !! Number of learnable parameters | ||
| 673 | end function get_num_params | ||
| 674 | |||
| 675 | !! Interface for getting learnable parameters | ||
| 676 | pure module function get_params(this) result(params) | ||
| 677 | !! Get the learnable parameters of the network as a rank-1 array (pure; interface only) | ||
| 678 | class(network_type), intent(in) :: this | ||
| 679 | !! Network instance (read-only); this%num_params sizes the result | ||
| 680 | real(real32), dimension(this%num_params) :: params | ||
| 681 | !! Learnable parameters, length this%num_params | ||
| 682 | end function get_params | ||
| 683 | |||
| 684 | !! Interface for setting learnable parameters | ||
| 685 | module subroutine set_params(this, params) | ||
| 686 | !! Set the learnable parameters of the network from a rank-1 array (interface only) | ||
| 687 | class(network_type), intent(inout) :: this | ||
| 688 | !! Network instance to update | ||
| 689 | real(real32), dimension(this%num_params), intent(in) :: params | ||
| 690 | !! Learnable parameters, length this%num_params | ||
| 691 | end subroutine set_params | ||
| 692 | |||
| 693 | !! Interface for getting gradients of learnable parameters | ||
| 694 | pure module function get_gradients(this) result(gradients) | ||
| 695 | !! Get the gradients of the learnable parameters as a rank-1 array (pure; interface only) | ||
| 696 | class(network_type), intent(in) :: this | ||
| 697 | !! Network instance (read-only); this%num_params sizes the result | ||
| 698 | real(real32), dimension(this%num_params) :: gradients | ||
| 699 | !! Gradients, length this%num_params | ||
| 700 | end function get_gradients | ||
| 701 | |||
| 702 | !! Interface for setting learnable parameter gradients | ||
| 703 | module subroutine set_gradients(this, gradients) | ||
| 704 | !! Set the gradients of the learnable parameters (interface only) | ||
| 705 | class(network_type), intent(inout) :: this | ||
| 706 | !! Network instance to update | ||
| 707 | real(real32), dimension(..), intent(in) :: gradients | ||
| 708 | !! Gradients; assumed rank (NOTE(review): which ranks the implementation handles is not visible here) | ||
| 709 | end subroutine set_gradients | ||
| 710 | |||
| 711 | !! Interface for resetting learnable parameter gradients | ||
| 712 | module subroutine reset_gradients(this) | ||
| 713 | !! Reset the gradients of the learnable parameters (interface only; reset value is not visible here) | ||
| 714 | class(network_type), intent(inout) :: this | ||
| 715 | !! Network instance to update | ||
| 716 | end subroutine reset_gradients | ||
| 717 | |||
| 718 | module function get_output(this) result(output) | ||
| 719 | class(network_type), intent(in) :: this | ||
| 720 | !! Network instance whose output is returned (read-only) | ||
| 721 | type(array_type), dimension(:,:), allocatable :: output | ||
| 722 | !! Output, returned as an allocatable rank-2 array of array_type | ||
| 723 | end function get_output | ||
| 724 | |||
| 725 | module function get_output_shape(this) result(output_shape) | ||
| 726 | class(network_type), intent(in) :: this | ||
| 727 | !! Network instance whose output shape is returned (read-only) | ||
| 728 | integer, dimension(2) :: output_shape | ||
| 729 | !! Output shape as two integers (NOTE(review): meaning of each entry is not visible here) | ||
| 730 | end function get_output_shape | ||
| 731 | |||
| 732 | module subroutine extract_output_real(this, output) | ||
| 733 | class(network_type), intent(in) :: this | ||
| 734 | !! Network instance whose output is extracted (read-only) | ||
| 735 | real(real32), dimension(..), allocatable, intent(out) :: output | ||
| 736 | !! Output as real values; assumed-rank allocatable with intent(out), so it is deallocated on entry | ||
| 737 | end subroutine extract_output_real | ||
| 738 | |||
| 739 | module function accuracy_eval(this, output, start_index, end_index) & | ||
| 740 | result(accuracy) | ||
| 741 | !! Get the accuracy for the output over a range of batch indices (interface only) | ||
| 742 | class(network_type), intent(in) :: this | ||
| 743 | !! Network instance (read-only) | ||
| 744 | class(*), dimension(:,:), intent(in) :: output | ||
| 745 | !! Output to score; unlimited polymorphic rank-2 array (NOTE(review): expected vs predicted is not visible here) | ||
| 746 | integer, intent(in) :: start_index, end_index | ||
| 747 | !! Start and end batch indices | ||
| 748 | real(real32) :: accuracy | ||
| 749 | !! Accuracy value | ||
| 750 | end function accuracy_eval | ||
| 751 | |||
| 752 | module function loss_eval(this, start_index, end_index) result(loss) | ||
| 753 | !! Get the loss for the output; the result loss is a pointer to an array_type | ||
| 754 | ! Arguments (the result declaration follows the blank line below) | ||
| 755 | class(network_type), intent(inout), target :: this | ||
| 756 | !! Network instance; intent(inout) and target, so the implementation may modify it and point into it | ||
| 757 | integer, intent(in) :: start_index, end_index | ||
| 758 | !! Start and end batch indices | ||
| 759 | |||
| 760 | type(array_type), pointer :: loss | ||
| 761 | end function loss_eval | ||
| 762 | |||
| 763 | !! Interface for forward pass | ||
| 764 | module subroutine forward_generic2d(this, input) | ||
| 765 | !! Forward pass for generic (unlimited polymorphic) rank-2 input; interface only | ||
| 766 | class(network_type), intent(inout), target :: this | ||
| 767 | !! Network instance; intent(inout), so the forward pass may update its state | ||
| 768 | class(*), dimension(:,:), intent(in) :: input | ||
| 769 | !! Input data; accepted dynamic types are decided by the implementation (not visible here) | ||
| 770 | end subroutine forward_generic2d | ||
| 771 | |||
| 772 | module function forward_eval(this, input) result(output) | ||
| 773 | !! Forward pass evaluation returning a pointer to the output (interface only) | ||
| 774 | class(network_type), intent(inout), target :: this | ||
| 775 | !! Network instance; target, presumably so the pointer result can refer to data it owns - confirm | ||
| 776 | class(*), dimension(:,:), intent(in) :: input | ||
| 777 | !! Input data; unlimited polymorphic rank-2 array | ||
| 778 | type(array_type), pointer :: output(:,:) | ||
| 779 | !! Output data, returned as a rank-2 pointer to array_type | ||
| 780 | end function forward_eval | ||
| 781 | |||
| 782 | module function forward_eval_multi(this, input) result(output) | ||
| 783 | !! Forward pass evaluation for multiple outputs, returned as a pointer array (interface only) | ||
| 784 | class(network_type), intent(inout), target :: this | ||
| 785 | !! Network instance; target, presumably so the pointer result can refer to data it owns - confirm | ||
| 786 | class(*), dimension(:,:), intent(in) :: input | ||
| 787 | !! Input data; unlimited polymorphic rank-2 array | ||
| 788 | type(array_ptr_type), pointer :: output(:) | ||
| 789 | !! Output data, returned as a rank-1 pointer to array_ptr_type (presumably one entry per output - confirm) | ||
| 790 | end function forward_eval_multi | ||
| 791 | |||
| 792 | module subroutine nullify_graph(this) | ||
| 793 | !! Nullify graph data held by the network to free memory (interface only) | ||
| 794 | class(network_type), intent(inout) :: this | ||
| 795 | !! Network instance whose graph data is nullified | ||
| 796 | end subroutine nullify_graph | ||
| 797 | |||
| 798 | module subroutine post_epoch_hook(this, epoch, loss, accuracy) | ||
| 799 | !! Hook called after each training epoch (interface only; the body is defined outside this view). | ||
| 800 | !! The default implementation is a no-op; override it in a derived type to | ||
| 801 | !! add custom per-epoch behaviour (e.g. W&B metric logging). | ||
| 802 | class(network_type), intent(inout) :: this | ||
| 803 | !! Network instance; intent(inout), so an override may modify it | ||
| 804 | integer, intent(in) :: epoch | ||
| 805 | !! Current epoch number (1-based) | ||
| 806 | real(real32), intent(in) :: loss | ||
| 807 | !! Current loss value | ||
| 808 | real(real32), intent(in) :: accuracy | ||
| 809 | !! Current accuracy value | ||
| 810 | end subroutine post_epoch_hook | ||
| 811 | |||
| 812 | module function inverse_design_real( & | ||
| 813 | this, target, x_init, optimiser, steps & | ||
| 814 | ) result(x_opt) | ||
| 815 | !! Optimise an input so the network output matches a target (rank-2 real inputs) | ||
| 816 | class(network_type), intent(inout), target :: this | ||
| 817 | !! Network instance; intent(inout), so the optimisation may update its state | ||
| 818 | real(real32), dimension(:,:), intent(in) :: target | ||
| 819 | !! Target output values, rank-2 real array | ||
| 820 | real(real32), dimension(:,:), intent(in) :: x_init | ||
| 821 | !! Initial input values; its extents fix the shape of x_opt | ||
| 822 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 823 | !! Optional optimiser for input updates (defaults to network optimiser) | ||
| 824 | integer, intent(in) :: steps | ||
| 825 | !! Number of optimisation iterations | ||
| 826 | real(real32), dimension(size(x_init,1), size(x_init,2)) :: x_opt | ||
| 827 | !! Optimised input, same shape as x_init | ||
| 828 | end function inverse_design_real | ||
| 829 | |||
| 830 | module function inverse_design_array_0d( & | ||
| 831 | this, target, x_init, optimiser, steps & | ||
| 832 | ) result(x_opt) | ||
| 833 | !! Optimise an input so the network output matches a target (scalar array_type inputs) | ||
| 834 | class(network_type), intent(inout), target :: this | ||
| 835 | !! Network instance; intent(inout), so the optimisation may update its state | ||
| 836 | type(array_type), intent(in) :: target | ||
| 837 | !! Target output values, a single array_type | ||
| 838 | type(array_type), intent(in) :: x_init | ||
| 839 | !! Initial input values, a single array_type | ||
| 840 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 841 | !! Optional optimiser for input updates (defaults to network optimiser) | ||
| 842 | integer, intent(in) :: steps | ||
| 843 | !! Number of optimisation iterations | ||
| 844 | type(array_type) :: x_opt | ||
| 845 | !! Optimised input, a single array_type | ||
| 846 | end function inverse_design_array_0d | ||
| 847 | |||
| 848 | module function inverse_design_array_2d( & | ||
| 849 | this, target, x_init, optimiser, steps & | ||
| 850 | ) result(x_opt) | ||
| 851 | !! Optimise an input so the network output matches a target (rank-2 array_type inputs) | ||
| 852 | class(network_type), intent(inout), target :: this | ||
| 853 | !! Network instance; intent(inout), so the optimisation may update its state | ||
| 854 | type(array_type), dimension(:,:), intent(in) :: target | ||
| 855 | !! Target output values, rank-2 array of array_type | ||
| 856 | type(array_type), dimension(:,:), intent(in) :: x_init | ||
| 857 | !! Initial input values; its extents fix the shape of x_opt | ||
| 858 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 859 | !! Optional optimiser for input updates (defaults to network optimiser) | ||
| 860 | integer, intent(in) :: steps | ||
| 861 | !! Number of optimisation iterations | ||
| 862 | type(array_type), dimension(size(x_init,1), size(x_init,2)) :: x_opt | ||
| 863 | !! Optimised input, same shape as x_init | ||
| 864 | end function inverse_design_array_2d | ||
| 865 | end interface | ||
| 866 | |||
| 867 | interface get_sample | ||
| 868 | #ifdef __flang__ | ||
| 869 | module function get_sample_flang( & | ||
| 870 | input, start_index, end_index, batch_size & | ||
| 871 | ) result(sample) | ||
| 872 | !! Get a sample from an assumed-rank real input; this variant is compiled only when __flang__ is defined | ||
| 873 | implicit none | ||
| 874 | ! Arguments | ||
| 875 | integer, intent(in) :: start_index, end_index | ||
| 876 | !! Start and end indices of the sample (NOTE(review): the dimension they index is not visible here) | ||
| 877 | integer, intent(in) :: batch_size | ||
| 878 | !! Batch size | ||
| 879 | real(real32), dimension(..), intent(in) :: input | ||
| 880 | !! Input array (assumed rank) | ||
| 881 | ! Function result | ||
| 882 | real(real32), allocatable :: sample(:,:) | ||
| 883 | !! Sample, returned as an allocatable rank-2 real array | ||
| 884 | end function get_sample_flang | ||
| 885 | #else | ||
| 886 | module function get_sample_ptr( & | ||
| 887 | input, start_index, end_index, batch_size & | ||
| 888 | ) result(sample_ptr) | ||
| 889 | !! Get a sample from an assumed-rank real input; this variant is compiled when __flang__ is not defined | ||
| 890 | implicit none | ||
| 891 | ! Arguments | ||
| 892 | integer, intent(in) :: start_index, end_index | ||
| 893 | !! Start and end indices of the sample (NOTE(review): the dimension they index is not visible here) | ||
| 894 | integer, intent(in) :: batch_size | ||
| 895 | !! Batch size | ||
| 896 | real(real32), dimension(..), intent(in), target :: input | ||
| 897 | !! Input array (assumed rank); target, presumably so the pointer result can alias it - confirm | ||
| 898 | ! Function result | ||
| 899 | real(real32), pointer :: sample_ptr(:,:) | ||
| 900 | !! Rank-2 pointer to the sample | ||
| 901 | end function get_sample_ptr | ||
| 902 | #endif | ||
| 903 | module function get_sample_array( & | ||
| 904 | input, start_index, end_index, batch_size, as_graph& | ||
| 905 | ) result(sample) | ||
| 906 | !! Get a sample for mixed input supplied as a rank-2 array of array_type | ||
| 907 | integer, intent(in) :: start_index, end_index | ||
| 908 | !! Start and end indices of the sample | ||
| 909 | integer, intent(in) :: batch_size | ||
| 910 | !! Batch size | ||
| 911 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 912 | !! Input array (polymorphic array_type, rank 2) | ||
| 913 | logical, intent(in) :: as_graph | ||
| 914 | !! Boolean whether to treat the input as a graph | ||
| 915 | type(array_type), dimension(:,:), allocatable :: sample | ||
| 916 | !! Sample, returned as an allocatable rank-2 array of array_type | ||
| 917 | end function get_sample_array | ||
| 918 | module function get_sample_graph1d( & | ||
| 919 | input, start_index, end_index, batch_size & | ||
| 920 | ) result(sample) | ||
| 921 | !! Get a sample for graph input supplied as a rank-1 array | ||
| 922 | integer, intent(in) :: start_index, end_index | ||
| 923 | !! Start and end indices of the sample | ||
| 924 | integer, intent(in) :: batch_size | ||
| 925 | !! Batch size; sets the second extent of the result | ||
| 926 | class(graph_type), dimension(:), intent(in) :: input | ||
| 927 | !! Input array of graphs (rank 1) | ||
| 928 | type(graph_type), dimension(1, batch_size) :: sample | ||
| 929 | !! Sample array of shape (1, batch_size) | ||
| 930 | end function get_sample_graph1d | ||
| 931 | module function get_sample_graph2d( & | ||
| 932 | input, start_index, end_index, batch_size & | ||
| 933 | ) result(sample) | ||
| 934 | !! Get a sample for graph input supplied as a rank-2 array | ||
| 935 | integer, intent(in) :: start_index, end_index | ||
| 936 | !! Start and end indices of the sample | ||
| 937 | integer, intent(in) :: batch_size | ||
| 938 | !! Batch size; sets the second extent of the result | ||
| 939 | class(graph_type), dimension(:,:), intent(in) :: input | ||
| 940 | !! Input array of graphs (rank 2) | ||
| 941 | type(graph_type), dimension(size(input,1), batch_size) :: sample | ||
| 942 | !! Sample array of shape (size(input,1), batch_size) | ||
| 943 | end function get_sample_graph2d | ||
| 944 | end interface get_sample | ||
| 945 | |||
| 946 | ✗ | end module athena__network | |
| 947 |