| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__network | ||
| 2 | !! Module containing the network class used to define a neural network | ||
| 3 | !! | ||
| 4 | !! This module contains the types and interfaces for the network class used | ||
| 5 | !! to define a neural network. | ||
| 6 | !! The network class is used to define a neural network with overloaded | ||
| 7 | !! procedures for training, testing, predicting, and updating the network. | ||
| 8 | !! The network class is also used to define the network structure and | ||
| 9 | !! compile the network with an optimiser, loss function, and accuracy | ||
| 10 | !! function. | ||
| 11 | use coreutils, only: real32 | ||
| 12 | use graphstruc, only: graph_type | ||
| 13 | use athena__metrics, only: metric_dict_type | ||
| 14 | use athena__optimiser, only: base_optimiser_type | ||
| 15 | use athena__loss, only: base_loss_type | ||
| 16 | use athena__accuracy, only: comp_acc_func => compute_accuracy_function | ||
| 17 | use athena__base_layer, only: base_layer_type | ||
| 18 | use diffstruc, only: array_type | ||
| 19 | use athena__misc_types, only: & | ||
| 20 | onnx_node_type, onnx_initialiser_type, onnx_tensor_type | ||
| 21 | use athena__container_layer, only: container_layer_type | ||
| 22 | use athena__diffstruc_extd, only: array_ptr_type | ||
| 23 | implicit none | ||
| 24 | |||
| 25 | |||
| 26 | private | ||
| 27 | |||
| 28 | public :: network_type | ||
| 29 | |||
| 30 | |||
| 31 | type :: network_type | ||
| 32 | !! Type for defining a neural network with overloaded procedures | ||
| 33 | character(len=:), allocatable :: name | ||
| 34 | !! Name of the network | ||
| 35 | real(real32) :: accuracy_val, loss_val | ||
| 36 | !! Accuracy and loss of the network | ||
| 37 | integer :: batch_size = 0 | ||
| 38 | !! Batch size | ||
| 39 | integer :: epoch = 0 | ||
| 40 | !! Epoch number | ||
| 41 | integer :: num_layers = 0 | ||
| 42 | !! Number of layers | ||
| 43 | integer :: num_outputs = 0 | ||
| 44 | !! Number of outputs | ||
| 45 | integer :: num_params = 0 | ||
| 46 | !! Number of parameters | ||
| 47 | logical :: use_graph_input = .false. | ||
| 48 | !! Boolean flag for graph input | ||
| 49 | logical :: use_graph_output = .false. | ||
| 50 | !! Boolean flag for graph output | ||
| 51 | class(base_optimiser_type), allocatable :: optimiser | ||
| 52 | !! Optimiser for the network | ||
| 53 | class(base_loss_type), allocatable :: loss | ||
| 54 | !! Loss method for the network | ||
| 55 | type(metric_dict_type), dimension(2) :: metrics | ||
| 56 | !! Metrics for the network | ||
| 57 | type(container_layer_type), allocatable, dimension(:) :: model | ||
| 58 | !! Model layers | ||
| 59 | character(len=:), allocatable :: loss_method, accuracy_method | ||
| 60 | !! Loss and accuracy method names | ||
| 61 | procedure(comp_acc_func), nopass, pointer :: get_accuracy => null() | ||
| 62 | !! Pointer to accuracy function | ||
| 63 | integer, dimension(:), allocatable :: vertex_order | ||
| 64 | !! Order of vertices | ||
| 65 | integer, dimension(:), allocatable :: root_vertices, leaf_vertices | ||
| 66 | !! Root vertices and leaf (output) vertices | ||
| 67 | type(graph_type) :: auto_graph | ||
| 68 | !! Graph structure for the network | ||
| 69 | |||
| 70 | type(array_type), dimension(:,:), allocatable :: input_array | ||
| 71 | !! Input array for the network | ||
| 72 | type(graph_type), dimension(:,:), allocatable :: input_graph | ||
| 73 | !! Input graph for the network | ||
| 74 | type(array_type), dimension(:,:), allocatable :: expected_array | ||
| 75 | !! Expected output array for the network | ||
| 76 | contains | ||
| 77 | procedure, pass(this) :: print | ||
| 78 | !! Print the network to file | ||
| 79 | procedure, pass(this) :: print_summary | ||
| 80 | !! Print a summary of the network architecture | ||
| 81 | procedure, pass(this) :: read | ||
| 82 | !! Read the network from a file | ||
| 83 | procedure, pass(this), private :: read_network_settings | ||
| 84 | !! Read network settings from a file | ||
| 85 | procedure, pass(this), private :: read_optimiser_settings | ||
| 86 | !! Read optimiser settings from a file | ||
| 87 | procedure, pass(this) :: build_from_onnx | ||
| 88 | !! Build network from ONNX nodes and initialisers | ||
| 89 | procedure, pass(this) :: add | ||
| 90 | !! Add a layer to the network | ||
| 91 | procedure, pass(this) :: reset | ||
| 92 | !! Reset the network | ||
| 93 | procedure, pass(this) :: compile | ||
| 94 | !! Compile the network | ||
| 95 | procedure, pass(this) :: set_batch_size | ||
| 96 | !! Set batch size | ||
| 97 | procedure, pass(this) :: set_metrics | ||
| 98 | !! Set network metrics | ||
| 99 | procedure, pass(this) :: set_loss | ||
| 100 | !! Set network loss method | ||
| 101 | procedure, pass(this) :: set_accuracy | ||
| 102 | !! Set network accuracy method | ||
| 103 | procedure, pass(this) :: reset_state | ||
| 104 | !! Reset hidden state of recurrent layers | ||
| 105 | |||
| 106 | procedure, pass(this) :: save_input => save_input_to_network | ||
| 107 | !! Convert and save polymorphic input to array or graph | ||
| 108 | procedure, pass(this) :: save_output => save_output_to_network | ||
| 109 | !! Convert and save polymorphic output to array or graph | ||
| 110 | |||
| 111 | procedure, pass(this) :: layer_from_id | ||
| 112 | !! Get the layer of the network from its ID | ||
| 113 | |||
| 114 | procedure, pass(this) :: train | ||
| 115 | !! Train the network | ||
| 116 | procedure, pass(this) :: test | ||
| 117 | !! Test the network | ||
| 118 | |||
| 119 | procedure, pass(this) :: predict_real | ||
| 120 | !! Return predicted results from supplied inputs using the trained network | ||
| 121 | procedure, pass(this) :: predict_array_from_real | ||
| 122 | !! Return predicted results as array from supplied inputs using the trained network | ||
| 123 | procedure, pass(this) :: predict_graph1d, predict_graph2d | ||
| 124 | !! Return predicted results from supplied inputs using the trained network (graph input) | ||
| 125 | procedure, pass(this) :: predict_array | ||
| 126 | !! Predict array type output for a generic input | ||
| 127 | procedure, pass(this) :: predict_generic | ||
| 128 | !! Predict generic type output for a generic input | ||
| 129 | generic :: predict => & | ||
| 130 | predict_real, predict_graph1d, predict_graph2d, & | ||
| 131 | predict_array, predict_array_from_real | ||
| 132 | !! Predict function for different input types | ||
| 133 | |||
| 134 | |||
| 135 | procedure, pass(this), private :: dfs | ||
| 136 | !! Depth first search | ||
| 137 | procedure, pass(this), private :: build_vertex_order | ||
| 138 | !! Generate vertex order | ||
| 139 | procedure, pass(this), private :: build_root_vertices | ||
| 140 | !! Calculate root vertices | ||
| 141 | procedure, pass(this), private :: build_leaf_vertices | ||
| 142 | !! Calculate output vertices | ||
| 143 | |||
| 144 | procedure, pass(this) :: reduce => network_reduction | ||
| 145 | !! Reduce two networks down to one (i.e. add two networks - parallel) | ||
| 146 | procedure, pass(this) :: copy => network_copy | ||
| 147 | !! Copy a network | ||
| 148 | |||
| 149 | procedure, pass(this) :: get_num_params | ||
| 150 | !! Get number of learnable parameters in the network | ||
| 151 | procedure, pass(this) :: get_params | ||
| 152 | !! Get learnable parameters | ||
| 153 | procedure, pass(this) :: set_params | ||
| 154 | !! Set learnable parameters | ||
| 155 | procedure, pass(this) :: get_gradients | ||
| 156 | !! Get gradients of learnable parameters | ||
| 157 | procedure, pass(this) :: set_gradients | ||
| 158 | !! Set learnable parameter gradients | ||
| 159 | procedure, pass(this) :: reset_gradients | ||
| 160 | !! Reset learnable parameter gradients | ||
| 161 | procedure, pass(this) :: get_output | ||
| 162 | !! Get the output of the network | ||
| 163 | procedure, pass(this) :: get_output_shape | ||
| 164 | !! Get the output shape of the network | ||
| 165 | procedure, pass(this) :: extract_output => extract_output_real | ||
| 166 | !! Extract network output as real array (only works for single output layer models) | ||
| 167 | |||
| 168 | procedure, pass(this) :: forward => forward_generic2d | ||
| 169 | !! Forward pass for generic 2D input | ||
| 170 | procedure, pass(this) :: forward_eval | ||
| 171 | !! Forward pass and return pointer to output (only works for single output layer models) | ||
| 172 | procedure, pass(this) :: accuracy_eval | ||
| 173 | !! Get the accuracy for the output | ||
| 174 | procedure, pass(this) :: loss_eval | ||
| 175 | !! Get the loss for the output | ||
| 176 | procedure, pass(this) :: update | ||
| 177 | !! Update the learnable parameters of the network based on gradients | ||
| 178 | |||
| 179 | procedure, pass(this) :: nullify_graph | ||
| 180 | !! Nullify graph data in the network to free memory | ||
| 181 | end type network_type | ||
| 182 | |||
| 183 | interface network_type | ||
| 184 | !! Interface for setting up the network (network initialisation) | ||
| 185 | module function network_setup( & | ||
| 186 | layers, & | ||
| 187 | optimiser, loss_method, accuracy_method, & | ||
| 188 | metrics, batch_size & | ||
| 189 | ) result(network) | ||
| 190 | !! Set up the network | ||
| 191 | type(container_layer_type), dimension(:), intent(in) :: layers | ||
| 192 | !! Layers | ||
| 193 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 194 | !! Optimiser | ||
| 195 | class(*), optional, intent(in) :: loss_method | ||
| 196 | !! Loss method | ||
| 197 | character(*), optional, intent(in) :: accuracy_method | ||
| 198 | !! Accuracy method | ||
| 199 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 200 | !! Metrics | ||
| 201 | integer, optional, intent(in) :: batch_size | ||
| 202 | !! Batch size | ||
| 203 | type(network_type) :: network | ||
| 204 | !! Instance of the network | ||
| 205 | end function network_setup | ||
| 206 | end interface network_type | ||
| 207 | |||
| 208 | interface | ||
| 209 | !! Interface for printing the network to file | ||
| 210 | module subroutine print(this, file) | ||
| 211 | !! Print the network to file | ||
| 212 | class(network_type), intent(in) :: this | ||
| 213 | !! Instance of the network | ||
| 214 | character(*), intent(in) :: file | ||
| 215 | !! File name | ||
| 216 | end subroutine print | ||
| 217 | |||
| 218 | !! Interface for printing a summary of the network | ||
| 219 | module subroutine print_summary(this) | ||
| 220 | !! Print a summary of the network architecture | ||
| 221 | class(network_type), intent(in) :: this | ||
| 222 | !! Instance of the network | ||
| 223 | end subroutine print_summary | ||
| 224 | |||
| 225 | !! Interface for reading the network from a file | ||
| 226 | module subroutine read(this, file) | ||
| 227 | !! Read the network from a file | ||
| 228 | class(network_type), intent(inout) :: this | ||
| 229 | !! Instance of the network | ||
| 230 | character(*), intent(in) :: file | ||
| 231 | !! File name | ||
| 232 | end subroutine read | ||
| 233 | |||
| 234 | !! Interface for reading network settings from a file | ||
| 235 | module subroutine read_network_settings(this, unit) | ||
| 236 | !! Read network settings from a file | ||
| 237 | class(network_type), intent(inout) :: this | ||
| 238 | !! Instance of the network | ||
| 239 | integer, intent(in) :: unit | ||
| 240 | !! Unit number for input | ||
| 241 | end subroutine read_network_settings | ||
| 242 | |||
| 243 | !! Interface for reading optimiser settings from a file | ||
| 244 | module subroutine read_optimiser_settings(this, unit) | ||
| 245 | !! Read optimiser settings from a file | ||
| 246 | class(network_type), intent(inout) :: this | ||
| 247 | !! Instance of the network | ||
| 248 | integer, intent(in) :: unit | ||
| 249 | !! Unit number for input | ||
| 250 | end subroutine read_optimiser_settings | ||
| 251 | |||
| 252 | !! Interface for building network from ONNX nodes and initialisers | ||
| 253 | module subroutine build_from_onnx( & | ||
| 254 | this, nodes, initialisers, inputs, value_info, verbose & | ||
| 255 | ) | ||
| 256 | !! Build network from ONNX nodes and initialisers | ||
| 257 | class(network_type), intent(inout) :: this | ||
| 258 | !! Instance of the network | ||
| 259 | type(onnx_node_type), dimension(:), intent(in) :: nodes | ||
| 260 | !! Array of ONNX nodes | ||
| 261 | type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers | ||
| 262 | !! Array of ONNX initialisers | ||
| 263 | type(onnx_tensor_type), dimension(:), intent(in) :: inputs | ||
| 264 | !! Array of ONNX input tensors | ||
| 265 | type(onnx_tensor_type), dimension(:), intent(in) :: value_info | ||
| 266 | !! Array of ONNX value info tensors | ||
| 267 | integer, optional, intent(in) :: verbose | ||
| 268 | !! Verbosity level | ||
| 269 | end subroutine build_from_onnx | ||
| 270 | |||
| 271 | !! Interface for adding a layer to the network | ||
| 272 | module subroutine add(this, layer, input_list, output_list, operator) | ||
| 273 | !! Add a layer to the network | ||
| 274 | class(network_type), intent(inout) :: this | ||
| 275 | !! Instance of the network | ||
| 276 | class(base_layer_type), intent(in) :: layer | ||
| 277 | !! Layer to add | ||
| 278 | integer, dimension(:), intent(in), optional :: input_list, output_list | ||
| 279 | !! Input and output list | ||
| 280 | class(*), optional, intent(in) :: operator | ||
| 281 | !! Operator | ||
| 282 | end subroutine add | ||
| 283 | |||
| 284 | !! Interface for resetting the network | ||
| 285 | module subroutine reset(this) | ||
| 286 | !! Reset the network | ||
| 287 | class(network_type), intent(inout) :: this | ||
| 288 | !! Instance of the network | ||
| 289 | end subroutine reset | ||
| 290 | |||
| 291 | !! Interface for compiling the network | ||
| 292 | module subroutine compile( & | ||
| 293 | this, optimiser, loss_method, accuracy_method, & | ||
| 294 | metrics, batch_size, verbose & | ||
| 295 | ) | ||
| 296 | !! Compile the network | ||
| 297 | class(network_type), intent(inout) :: this | ||
| 298 | !! Instance of the network | ||
| 299 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 300 | !! Optimiser | ||
| 301 | class(*), optional, intent(in) :: loss_method | ||
| 302 | !! Loss method | ||
| 303 | character(*), optional, intent(in) :: accuracy_method | ||
| 304 | !! Accuracy method | ||
| 305 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 306 | !! Metrics | ||
| 307 | integer, optional, intent(in) :: batch_size | ||
| 308 | !! Batch size | ||
| 309 | integer, optional, intent(in) :: verbose | ||
| 310 | !! Verbosity level | ||
| 311 | end subroutine compile | ||
| 312 | |||
| 313 | !! Interface for setting batch size | ||
| 314 | module subroutine set_batch_size(this, batch_size) | ||
| 315 | !! Set batch size | ||
| 316 | class(network_type), intent(inout) :: this | ||
| 317 | !! Instance of the network | ||
| 318 | integer, intent(in) :: batch_size | ||
| 319 | !! Batch size | ||
| 320 | end subroutine set_batch_size | ||
| 321 | |||
| 322 | !! Interface for setting network metrics | ||
| 323 | module subroutine set_metrics(this, metrics) | ||
| 324 | !! Set network metrics | ||
| 325 | class(network_type), intent(inout) :: this | ||
| 326 | !! Instance of the network | ||
| 327 | class(*), dimension(..), intent(in) :: metrics | ||
| 328 | !! Metrics | ||
| 329 | end subroutine set_metrics | ||
| 330 | |||
| 331 | !! Interface for setting network loss method | ||
| 332 | module subroutine set_loss(this, loss_method, verbose) | ||
| 333 | !! Set network loss method | ||
| 334 | class(network_type), intent(inout) :: this | ||
| 335 | !! Instance of the network | ||
| 336 | class(*), intent(in) :: loss_method | ||
| 337 | !! Loss method | ||
| 338 | integer, optional, intent(in) :: verbose | ||
| 339 | !! Verbosity level | ||
| 340 | end subroutine set_loss | ||
| 341 | |||
| 342 | !! Interface for setting network accuracy method | ||
| 343 | module subroutine set_accuracy(this, accuracy_method, verbose) | ||
| 344 | !! Set network accuracy method | ||
| 345 | class(network_type), intent(inout) :: this | ||
| 346 | !! Instance of the network | ||
| 347 | character(*), intent(in) :: accuracy_method | ||
| 348 | !! Accuracy method | ||
| 349 | integer, optional, intent(in) :: verbose | ||
| 350 | !! Verbosity level | ||
| 351 | end subroutine set_accuracy | ||
| 352 | |||
| 353 | !! Interface for resetting state of recurrent layers | ||
| 354 | module subroutine reset_state(this) | ||
| 355 | !! Reset hidden state of recurrent layers | ||
| 356 | class(network_type), intent(inout) :: this | ||
| 357 | !! Instance of the network | ||
| 358 | end subroutine reset_state | ||
| 359 | |||
| 360 | !! Interface for saving input to network | ||
| 361 | module function save_input_to_network( this, input ) result(num_samples) | ||
| 362 | !! Convert and save polymorphic input to array or graph | ||
| 363 | class(network_type), intent(inout) :: this | ||
| 364 | !! Instance of network | ||
| 365 | class(*), dimension(..), intent(in) :: input | ||
| 366 | !! Input | ||
| 367 | integer :: num_samples | ||
| 368 | !! Number of samples | ||
| 369 | end function save_input_to_network | ||
| 370 | |||
| 371 | !! Interface for saving output to network | ||
| 372 | module subroutine save_output_to_network( this, output ) | ||
| 373 | !! Convert and save polymorphic output to array or graph | ||
| 374 | class(network_type), intent(inout) :: this | ||
| 375 | !! Instance of network | ||
| 376 | class(*), dimension(:,:), intent(in) :: output | ||
| 377 | !! Output | ||
| 378 | end subroutine save_output_to_network | ||
| 379 | |||
| 380 | module function layer_from_id(this, id) result(layer) | ||
| 381 | !! Get the layer of the network from its ID | ||
| 382 | class(network_type), intent(in), target :: this | ||
| 383 | !! Instance of the network | ||
| 384 | integer, intent(in) :: id | ||
| 385 | !! Layer ID | ||
| 386 | class(base_layer_type), pointer :: layer | ||
| 387 | !! Layer pointer | ||
| 388 | end function layer_from_id | ||
| 389 | |||
| 390 | |||
| 391 | !! Interface for training the network | ||
| 392 | module subroutine train( & | ||
| 393 | this, input, output, num_epochs, batch_size, & | ||
| 394 | plateau_threshold, shuffle_batches, batch_print_step, verbose & | ||
| 395 | ) | ||
| 396 | !! Train the network | ||
| 397 | class(network_type), intent(inout) :: this | ||
| 398 | !! Instance of the network | ||
| 399 | class(*), dimension(..), intent(in) :: input | ||
| 400 | !! Input data | ||
| 401 | class(*), dimension(:,:), intent(in) :: output | ||
| 402 | !! Expected output data (data labels) | ||
| 403 | integer, intent(in) :: num_epochs | ||
| 404 | !! Number of epochs to train for | ||
| 405 | integer, optional, intent(in) :: batch_size | ||
| 406 | !! Batch size (DEPRECATED) | ||
| 407 | real(real32), optional, intent(in) :: plateau_threshold | ||
| 408 | !! Threshold for checking learning plateau | ||
| 409 | logical, optional, intent(in) :: shuffle_batches | ||
| 410 | !! Shuffle batch order | ||
| 411 | integer, optional, intent(in) :: batch_print_step | ||
| 412 | !! Print step for batch | ||
| 413 | integer, optional, intent(in) :: verbose | ||
| 414 | !! Verbosity level | ||
| 415 | end subroutine train | ||
| 416 | |||
| 417 | !! Interface for testing the network | ||
| 418 | module subroutine test(this, input, output, verbose) | ||
| 419 | !! Test the network | ||
| 420 | class(network_type), intent(inout) :: this | ||
| 421 | !! Instance of the network | ||
| 422 | class(*), dimension(..), intent(in) :: input | ||
| 423 | !! Input data | ||
| 424 | class(*), dimension(:,:), intent(in) :: output | ||
| 425 | !! Expected output data (data labels) | ||
| 426 | integer, optional, intent(in) :: verbose | ||
| 427 | !! Verbosity level | ||
| 428 | end subroutine test | ||
| 429 | |||
| 430 | !! Interface for returning predicted results from supplied inputs | ||
| 431 | !! using the trained network | ||
| 432 | module function predict_real(this, input, verbose) result(output) | ||
| 433 | !! Get predicted results from supplied inputs using the trained network | ||
| 434 | class(network_type), intent(inout) :: this | ||
| 435 | !! Instance of the network | ||
| 436 | real(real32), dimension(..), intent(in) :: input | ||
| 437 | !! Input data | ||
| 438 | integer, optional, intent(in) :: verbose | ||
| 439 | !! Verbosity level | ||
| 440 | real(real32), dimension(:,:), allocatable :: output | ||
| 441 | !! Predicted output data | ||
| 442 | end function predict_real | ||
| 443 | |||
| 444 | module function predict_array_from_real( & | ||
| 445 | this, input, output_as_array, verbose & | ||
| 446 | ) result(output) | ||
| 447 | !! Get predicted results as array from supplied inputs using the trained network | ||
| 448 | class(network_type), intent(inout) :: this | ||
| 449 | !! Instance of the network | ||
| 450 | class(*), dimension(..), intent(in) :: input | ||
| 451 | !! Input data | ||
| 452 | logical, intent(in) :: output_as_array | ||
| 453 | !! Whether to output as array | ||
| 454 | integer, optional, intent(in) :: verbose | ||
| 455 | !! Verbosity level | ||
| 456 | type(array_type), dimension(:,:), allocatable :: output | ||
| 457 | !! Predicted output data as array | ||
| 458 | end function predict_array_from_real | ||
| 459 | |||
| 460 | !! Interface for returning predicted results from supplied inputs | ||
| 461 | !! using the trained network (graph input) | ||
| 462 | module function predict_graph1d(this, input, verbose) result(output) | ||
| 463 | !! Get predicted results from supplied inputs using the trained network | ||
| 464 | class(network_type), intent(inout) :: this | ||
| 465 | !! Instance of the network | ||
| 466 | type(graph_type), dimension(:), intent(in) :: input | ||
| 467 | !! Input data | ||
| 468 | integer, optional, intent(in) :: verbose | ||
| 469 | !! Verbosity level | ||
| 470 | type(graph_type), dimension(size(this%leaf_vertices),size(input)) :: & | ||
| 471 | output | ||
| 472 | !! Predicted output data | ||
| 473 | end function predict_graph1d | ||
| 474 | module function predict_graph2d(this, input, verbose) result(output) | ||
| 475 | !! Get predicted results from supplied inputs using the trained network | ||
| 476 | class(network_type), intent(inout) :: this | ||
| 477 | !! Instance of the network | ||
| 478 | type(graph_type), dimension(:,:), intent(in) :: input | ||
| 479 | !! Input data | ||
| 480 | integer, optional, intent(in) :: verbose | ||
| 481 | !! Verbosity level | ||
| 482 | type(graph_type), dimension(size(this%leaf_vertices),size(input, 2)) :: & | ||
| 483 | output | ||
| 484 | !! Predicted output data | ||
| 485 | end function predict_graph2d | ||
| 486 | |||
| 487 | module function predict_array( this, input, verbose ) & | ||
| 488 | result(output) | ||
| 489 | !! Predict array type output for an array type input | ||
| 490 | class(network_type), intent(inout) :: this | ||
| 491 | !! Instance of network | ||
| 492 | class(array_type), dimension(..), intent(in) :: input | ||
| 493 | !! Input data (array type) | ||
| 494 | integer, intent(in), optional :: verbose | ||
| 495 | !! Verbosity level | ||
| 496 | type(array_type), dimension(:,:), allocatable :: output | ||
| 497 | end function predict_array | ||
| 498 | |||
| 499 | module function predict_generic( this, input, verbose, output_as_graph ) & | ||
| 500 | result(output) | ||
| 501 | !! Predict the output for a generic input | ||
| 502 | class(network_type), intent(inout) :: this | ||
| 503 | !! Instance of network | ||
| 504 | class(*), dimension(:,:), intent(in) :: input | ||
| 505 | !! Input data (polymorphic) | ||
| 506 | integer, intent(in), optional :: verbose | ||
| 507 | !! Verbosity level | ||
| 508 | logical, intent(in), optional :: output_as_graph | ||
| 509 | !! Boolean whether to output as graph | ||
| 510 | class(*), dimension(:,:), allocatable :: output | ||
| 511 | end function predict_generic | ||
| 512 | |||
| 513 | !! Interface for updating the learnable parameters of the network | ||
| 514 | !! based on gradients | ||
| 515 | module subroutine update(this) | ||
| 516 | !! Update the learnable parameters of the network based on gradients | ||
| 517 | class(network_type), intent(inout) :: this | ||
| 518 | !! Instance of the network | ||
| 519 | end subroutine update | ||
| 520 | |||
| 521 | !! Interface for generating vertex order | ||
| 522 | module subroutine build_vertex_order(this) | ||
| 523 | !! Generate vertex order | ||
| 524 | class(network_type), intent(inout) :: this | ||
| 525 | !! Instance of the network | ||
| 526 | end subroutine build_vertex_order | ||
| 527 | |||
| 528 | !! Interface for depth first search | ||
| 529 | recursive module subroutine dfs( & | ||
| 530 | this, vertex_index, visited, order, order_index & | ||
| 531 | ) | ||
| 532 | !! Depth first search | ||
| 533 | class(network_type), intent(in) :: this | ||
| 534 | !! Instance of the network | ||
| 535 | integer, intent(in) :: vertex_index | ||
| 536 | !! Vertex index | ||
| 537 | logical, dimension(this%auto_graph%num_vertices), intent(inout) :: & | ||
| 538 | visited | ||
| 539 | !! Visited vertices | ||
| 540 | integer, dimension(this%auto_graph%num_vertices), intent(inout) :: order | ||
| 541 | !! Order of vertices | ||
| 542 | integer, intent(inout) :: order_index | ||
| 543 | !! Index of order | ||
| 544 | end subroutine dfs | ||
| 545 | |||
| 546 | !! Interface for calculating root vertices | ||
| 547 | module subroutine build_root_vertices(this) | ||
| 548 | !! Calculate root vertices | ||
| 549 | class(network_type), intent(inout) :: this | ||
| 550 | !! Instance of the network | ||
| 551 | end subroutine build_root_vertices | ||
| 552 | |||
| 553 | !! Interface for calculating output vertices | ||
| 554 | module subroutine build_leaf_vertices(this) | ||
| 555 | !! Calculate output vertices | ||
| 556 | class(network_type), intent(inout) :: this | ||
| 557 | !! Instance of the network | ||
| 558 | end subroutine build_leaf_vertices | ||
| 559 | |||
| 560 | !! Interface for reducing two networks down to one | ||
| 561 | !! (i.e. add two networks - parallel) | ||
| 562 | module subroutine network_reduction(this, source) | ||
| 563 | !! Reduce two networks down to one (i.e. add two networks - parallel) | ||
| 564 | class(network_type), intent(inout) :: this | ||
| 565 | !! Instance of the network | ||
| 566 | type(network_type), intent(in) :: source | ||
| 567 | !! Source network | ||
| 568 | end subroutine network_reduction | ||
| 569 | |||
| 570 | !! Interface for copying a network | ||
| 571 | module subroutine network_copy(this, source) | ||
| 572 | !! Copy a network | ||
| 573 | class(network_type), intent(inout) :: this | ||
| 574 | !! Instance of the network | ||
| 575 | type(network_type), intent(in), target :: source | ||
| 576 | !! Source network | ||
| 577 | end subroutine network_copy | ||
| 578 | |||
| 579 | !! Interface for getting number of learnable parameters in the network | ||
| 580 | pure module function get_num_params(this) result(num_params) | ||
| 581 | !! Get number of learnable parameters in the network | ||
| 582 | class(network_type), intent(in) :: this | ||
| 583 | !! Instance of the network | ||
| 584 | integer :: num_params | ||
| 585 | !! Number of parameters | ||
| 586 | end function get_num_params | ||
| 587 | |||
| 588 | !! Interface for getting learnable parameters | ||
| 589 | pure module function get_params(this) result(params) | ||
| 590 | !! Get learnable parameters | ||
| 591 | class(network_type), intent(in) :: this | ||
| 592 | !! Instance of the network | ||
| 593 | real(real32), dimension(this%num_params) :: params | ||
| 594 | !! Learnable parameters | ||
| 595 | end function get_params | ||
| 596 | |||
| 597 | !! Interface for setting learnable parameters | ||
| 598 | module subroutine set_params(this, params) | ||
| 599 | !! Set learnable parameters | ||
| 600 | class(network_type), intent(inout) :: this | ||
| 601 | !! Instance of the network | ||
| 602 | real(real32), dimension(this%num_params), intent(in) :: params | ||
| 603 | !! Learnable parameters | ||
| 604 | end subroutine set_params | ||
| 605 | |||
| 606 | !! Interface for getting gradients of learnable parameters | ||
| 607 | pure module function get_gradients(this) result(gradients) | ||
| 608 | !! Get gradients of learnable parameters | ||
| 609 | class(network_type), intent(in) :: this | ||
| 610 | !! Instance of the network | ||
| 611 | real(real32), dimension(this%num_params) :: gradients | ||
| 612 | !! Gradients | ||
| 613 | end function get_gradients | ||
| 614 | |||
| 615 | !! Interface for setting learnable parameter gradients | ||
| 616 | module subroutine set_gradients(this, gradients) | ||
| 617 | !! Set learnable parameter gradients | ||
| 618 | class(network_type), intent(inout) :: this | ||
| 619 | !! Instance of the network | ||
| 620 | real(real32), dimension(..), intent(in) :: gradients | ||
| 621 | !! Gradients | ||
| 622 | end subroutine set_gradients | ||
| 623 | |||
| 624 | !! Interface for resetting learnable parameter gradients | ||
| 625 | module subroutine reset_gradients(this) | ||
| 626 | !! Reset learnable parameter gradients | ||
| 627 | class(network_type), intent(inout) :: this | ||
| 628 | !! Instance of the network | ||
| 629 | end subroutine reset_gradients | ||
| 630 | |||
| 631 | module function get_output(this) result(output) | ||
| 632 | class(network_type), intent(in) :: this | ||
| 633 | !! Instance of the network | ||
| 634 | type(array_type), dimension(:,:), allocatable :: output | ||
| 635 | !! Output | ||
| 636 | end function get_output | ||
| 637 | |||
| 638 | module function get_output_shape(this) result(output_shape) | ||
| 639 | class(network_type), intent(in) :: this | ||
| 640 | !! Instance of the network | ||
| 641 | integer, dimension(2) :: output_shape | ||
| 642 | !! Output shape | ||
| 643 | end function get_output_shape | ||
| 644 | |||
| 645 | module subroutine extract_output_real(this, output) | ||
| 646 | class(network_type), intent(in) :: this | ||
| 647 | !! Instance of network | ||
| 648 | real(real32), dimension(..), allocatable, intent(out) :: output | ||
| 649 | !! Output | ||
| 650 | end subroutine extract_output_real | ||
| 651 | |||
| 652 | module function accuracy_eval(this, output, start_index, end_index) & | ||
| 653 | result(accuracy) | ||
| 654 | !! Get the accuracy for the output | ||
| 655 | class(network_type), intent(in) :: this | ||
| 656 | !! Instance of network | ||
| 657 | class(*), dimension(:,:), intent(in) :: output | ||
| 658 | !! Output | ||
| 659 | integer, intent(in) :: start_index, end_index | ||
| 660 | !! Start and end batch indices | ||
| 661 | real(real32) :: accuracy | ||
| 662 | !! Accuracy value | ||
| 663 | end function accuracy_eval | ||
| 664 | |||
| 665 | module function loss_eval(this, start_index, end_index) result(loss) | ||
| 666 | !! Get the loss for the output | ||
| 667 | ! Arguments | ||
| 668 | class(network_type), intent(inout), target :: this | ||
| 669 | !! Instance of network | ||
| 670 | integer, intent(in) :: start_index, end_index | ||
| 671 | !! Start and end batch indices | ||
| 672 | |||
| 673 | type(array_type), pointer :: loss | ||
| 674 | end function loss_eval | ||
| 675 | |||
| 676 | !! Interface for forward pass | ||
| 677 | module subroutine forward_generic2d(this, input) | ||
| 678 | !! Forward pass for generic 2D input | ||
| 679 | class(network_type), intent(inout), target :: this | ||
| 680 | !! Instance of the network | ||
| 681 | class(*), dimension(:,:), intent(in) :: input | ||
| 682 | !! Input data | ||
| 683 | end subroutine forward_generic2d | ||
| 684 | |||
| 685 | module function forward_eval(this, input) result(output) | ||
| 686 | !! Forward pass evaluation | ||
| 687 | class(network_type), intent(inout), target :: this | ||
| 688 | !! Instance of the network | ||
| 689 | class(*), dimension(:,:), intent(in) :: input | ||
| 690 | !! Input data | ||
| 691 | type(array_type), pointer :: output(:,:) | ||
| 692 | !! Output data | ||
| 693 | end function forward_eval | ||
| 694 | |||
| 695 | module function forward_eval_multi(this, input) result(output) | ||
| 696 | !! Forward pass evaluation for multiple outputs | ||
| 697 | class(network_type), intent(inout), target :: this | ||
| 698 | !! Instance of the network | ||
| 699 | class(*), dimension(:,:), intent(in) :: input | ||
| 700 | !! Input data | ||
| 701 | type(array_ptr_type), pointer :: output(:) | ||
| 702 | !! Output data | ||
| 703 | end function forward_eval_multi | ||
| 704 | |||
| 705 | module subroutine nullify_graph(this) | ||
| 706 | !! Nullify graph data in the network to free memory | ||
| 707 | class(network_type), intent(inout) :: this | ||
| 708 | !! Instance of the network | ||
| 709 | end subroutine nullify_graph | ||
| 710 | end interface | ||
| 711 | |||
| 712 | interface get_sample | ||
| 713 | module function get_sample_ptr( & | ||
| 714 | input, start_index, end_index, batch_size & | ||
| 715 | ) result(sample_ptr) | ||
| 716 | !! Get a rank-2 pointer to a batch sample from an assumed-rank real input | ||
| 717 | implicit none | ||
| 718 | ! Arguments | ||
| 719 | integer, intent(in) :: start_index, end_index | ||
| 720 | !! Start and end indices | ||
| 721 | integer, intent(in) :: batch_size | ||
| 722 | !! Batch size | ||
| 723 | real(real32), dimension(..), intent(in), target :: input | ||
| 724 | !! Input array | ||
| 725 | ! Function result | ||
| 726 | real(real32), pointer :: sample_ptr(:,:) | ||
| 727 | !! Pointer to sample | ||
| 728 | end function get_sample_ptr | ||
| 729 | module function get_sample_array( & | ||
| 730 | input, start_index, end_index, batch_size, as_graph& | ||
| 731 | ) result(sample) | ||
| 732 | !! Get sample for array input (optionally treated as a graph) | ||
| 733 | integer, intent(in) :: start_index, end_index | ||
| 734 | !! Start and end indices | ||
| 735 | integer, intent(in) :: batch_size | ||
| 736 | !! Batch size | ||
| 737 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 738 | !! Input array | ||
| 739 | logical, intent(in) :: as_graph | ||
| 740 | !! Boolean whether to treat the input as a graph | ||
| 741 | type(array_type), dimension(:,:), allocatable :: sample | ||
| 742 | !! Sample array | ||
| 743 | end function get_sample_array | ||
| 744 | module function get_sample_graph1d( & | ||
| 745 | input, start_index, end_index, batch_size & | ||
| 746 | ) result(sample) | ||
| 747 | !! Get sample for graph input | ||
| 748 | integer, intent(in) :: start_index, end_index | ||
| 749 | !! Start and end indices | ||
| 750 | integer, intent(in) :: batch_size | ||
| 751 | !! Batch size | ||
| 752 | class(graph_type), dimension(:), intent(in) :: input | ||
| 753 | !! Input graphs | ||
| 754 | type(graph_type), dimension(1, batch_size) :: sample | ||
| 755 | !! Sample array | ||
| 756 | end function get_sample_graph1d | ||
| 757 | module function get_sample_graph2d( & | ||
| 758 | input, start_index, end_index, batch_size & | ||
| 759 | ) result(sample) | ||
| 760 | !! Get sample for graph input | ||
| 761 | integer, intent(in) :: start_index, end_index | ||
| 762 | !! Start and end indices | ||
| 763 | integer, intent(in) :: batch_size | ||
| 764 | !! Batch size | ||
| 765 | class(graph_type), dimension(:,:), intent(in) :: input | ||
| 766 | !! Input graphs | ||
| 767 | type(graph_type), dimension(size(input,1), batch_size) :: sample | ||
| 768 | !! Sample array | ||
| 769 | end function get_sample_graph2d | ||
| 770 | end interface get_sample | ||
| 771 | |||
| 772 | − | end module athena__network | |
| 773 |