| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__base_layer | ||
| 2 | !! Module containing the abstract base layer type | ||
| 3 | !! | ||
| 4 | !! This module contains the abstract base layer type, from which all other | ||
| 5 | !! layers are derived. The module also contains the abstract derived types | ||
| 6 | !! for the following layer types: | ||
| 7 | !! - padding | ||
| 8 | !! - pooling | ||
| 9 | !! - dropout | ||
| 10 | !! - learnable | ||
| 11 | !! - convolutional | ||
| 12 | !! - batch normalisation | ||
| 13 | !! | ||
| 14 | !! The following procedures are based on code from the neural-fortran library | ||
| 15 | !! https://github.com/modern-fortran/neural-fortran/blob/main/src/nf/nf_layer.f90 | ||
| 16 | use coreutils, only: real32 | ||
| 17 | use athena__clipper, only: clip_type | ||
| 18 | use athena__misc_types, only: base_actv_type, base_init_type, facets_type, & | ||
| 19 | onnx_attribute_type, onnx_node_type, onnx_initialiser_type, & | ||
| 20 | onnx_tensor_type | ||
| 21 | use diffstruc, only: array_type | ||
| 22 | use athena__diffstruc_extd, only: array_ptr_type | ||
| 23 | use graphstruc, only: graph_type | ||
| 24 | implicit none | ||
| 25 | |||
| 26 | private | ||
| 27 | |||
| 28 | public :: base_layer_type | ||
| 29 | public :: pad_layer_type | ||
| 30 | public :: pool_layer_type | ||
| 31 | public :: drop_layer_type | ||
| 32 | public :: learnable_layer_type | ||
| 33 | public :: conv_layer_type | ||
| 34 | public :: batch_layer_type | ||
| 35 | public :: merge_layer_type | ||
| 36 | |||
| 37 | !------------------------------------------------------------------------------- | ||
| 38 | ! layer abstract type | ||
| 39 | !------------------------------------------------------------------------------- | ||
| 40 | type, abstract :: base_layer_type | ||
| 41 | !! Type for base layer, from which all other layers are derived | ||
| 42 | integer :: id | ||
| 43 | !! Unique identifier | ||
| 44 | integer :: input_rank = 0 | ||
| 45 | !! Rank of input data | ||
| 46 | integer :: output_rank = 0 | ||
| 47 | !! Rank of output data | ||
| 48 | logical :: inference = .false. | ||
| 49 | !! Inference mode | ||
| 50 | logical :: use_graph_input = .false. | ||
| 51 | !! Use graph input | ||
| 52 | logical :: use_graph_output = .false. | ||
| 53 | !! Use graph output | ||
| 54 | character(:), allocatable :: name | ||
| 55 | !! Layer name | ||
| 56 | character(4) :: type = 'base' | ||
| 57 | !! Layer type | ||
| 58 | character(20) :: subtype = repeat(" ",20) | ||
| 59 | type(graph_type), allocatable, dimension(:) :: graph | ||
| 60 | !! Graph structure of input data | ||
| 61 | class(array_type), allocatable, dimension(:,:) :: output | ||
| 62 | !! Output | ||
| 63 | integer, allocatable, dimension(:) :: input_shape | ||
| 64 | !! Input shape | ||
| 65 | integer, allocatable, dimension(:) :: output_shape | ||
| 66 | !! Output shape | ||
| 67 | contains | ||
| 68 | procedure, pass(this) :: set_rank => set_rank_base | ||
| 69 | !! Set the input and output ranks of the layer | ||
| 70 | procedure, pass(this) :: set_shape => set_shape_base | ||
| 71 | !! Set the input shape of the layer | ||
| 72 | procedure, pass(this) :: get_num_params => get_num_params_base | ||
| 73 | !! Get the number of parameters in the layer | ||
| 74 | procedure, pass(this) :: print => print_base | ||
| 75 | !! Print the layer to a file with additional information | ||
| 76 | procedure, pass(this) :: print_to_unit => print_to_unit_base | ||
| 77 | !! Print the layer to a unit | ||
| 78 | procedure, pass(this) :: get_attributes => get_attributes_base | ||
| 79 | !! Get the attributes of the layer (for ONNX export) | ||
| 80 | procedure, pass(this) :: extract_output => extract_output_base | ||
| 81 | !! Extract the output of the layer as a standard real array | ||
| 82 | procedure(initialise), deferred, pass(this) :: init | ||
| 83 | !! Initialise the layer | ||
| 84 | |||
| 85 | procedure, pass(this) :: forward => forward_base | ||
| 86 | !! Forward pass of layer | ||
| 87 | procedure, pass(this) :: forward_eval => forward_eval_base | ||
| 88 | !! Forward pass of layer and return output for evaluation | ||
| 89 | |||
| 90 | procedure, pass(this) :: nullify_graph => nullify_graph_base | ||
| 91 | !! Nullify the forward pass data of the layer to free memory | ||
| 92 | |||
| 93 | |||
| 94 | !! Forward pass of layer using derived array_type | ||
| 95 | procedure(read_layer), deferred, pass(this) :: read | ||
| 96 | !! Read layer from file | ||
| 97 | procedure, pass(this) :: build_from_onnx => build_from_onnx_base | ||
| 98 | !! Build layer from ONNX node and initialiser | ||
| 99 | procedure, pass(this) :: set_graph => set_graph_base | ||
| 100 | !! Set the graph structure of the input data !! this is adjacency and edge weighting | ||
| 101 | procedure, pass(this) :: emit_onnx_nodes => emit_onnx_nodes_base | ||
| 102 | !! Emit ONNX JSON nodes for this layer (format-aware and polymorphic) | ||
| 103 | procedure, pass(this) :: emit_onnx_graph_inputs => & | ||
| 104 | emit_onnx_graph_inputs_base | ||
| 105 | !! Emit graph input tensor declarations for this layer | ||
| 106 | end type base_layer_type | ||
| 107 | |||
| 108 | interface | ||
| 109 | module subroutine print_base(this, file, unit, print_header_footer) | ||
| 110 | !! Print the layer to a file with additional information | ||
| 111 | class(base_layer_type), intent(in) :: this | ||
| 112 | !! Instance of the layer | ||
| 113 | character(*), optional, intent(in) :: file | ||
| 114 | !! File name | ||
| 115 | integer, optional, intent(in) :: unit | ||
| 116 | !! Unit number | ||
| 117 | logical, optional, intent(in) :: print_header_footer | ||
| 118 | !! Boolean whether to print header and footer | ||
| 119 | end subroutine print_base | ||
| 120 | |||
| 121 | module subroutine print_to_unit_base(this, unit) | ||
| 122 | !! Print the layer to a file | ||
| 123 | class(base_layer_type), intent(in) :: this | ||
| 124 | !! Instance of the layer | ||
| 125 | integer, intent(in) :: unit | ||
| 126 | !! File unit | ||
| 127 | end subroutine print_to_unit_base | ||
| 128 | |||
| 129 | module function get_attributes_base(this) result(attributes) | ||
| 130 | !! Get the attributes of the layer (for ONNX export) | ||
| 131 | class(base_layer_type), intent(in) :: this | ||
| 132 | !! Instance of the layer | ||
| 133 | type(onnx_attribute_type), allocatable, dimension(:) :: attributes | ||
| 134 | !! Attributes of the layer | ||
| 135 | end function get_attributes_base | ||
| 136 | |||
| 137 | module subroutine set_rank_base(this, input_rank, output_rank) | ||
| 138 | !! Set the input and output ranks of the layer | ||
| 139 | class(base_layer_type), intent(inout) :: this | ||
| 140 | !! Instance of the layer | ||
| 141 | integer, intent(in) :: input_rank | ||
| 142 | !! Input rank | ||
| 143 | integer, intent(in) :: output_rank | ||
| 144 | !! Output rank | ||
| 145 | end subroutine set_rank_base | ||
| 146 | |||
| 147 | module subroutine set_shape_base(this, input_shape) | ||
| 148 | !! Set the input shape of the layer | ||
| 149 | class(base_layer_type), intent(inout) :: this | ||
| 150 | !! Instance of the layer | ||
| 151 | integer, dimension(:), intent(in) :: input_shape | ||
| 152 | !! Input shape | ||
| 153 | end subroutine set_shape_base | ||
| 154 | |||
| 155 | module subroutine extract_output_base(this, output) | ||
| 156 | !! Extract the output of the layer as a standard real array | ||
| 157 | class(base_layer_type), intent(in) :: this | ||
| 158 | !! Instance of the layer | ||
| 159 | real(real32), dimension(..), allocatable, intent(out) :: output | ||
| 160 | !! Output values | ||
| 161 | end subroutine extract_output_base | ||
| 162 | |||
| 163 | pure module function get_num_params_base(this) result(num_params) | ||
| 164 | class(base_layer_type), intent(in) :: this | ||
| 165 | integer :: num_params | ||
| 166 | end function get_num_params_base | ||
| 167 | end interface | ||
| 168 | |||
| 169 | |||
| 170 | interface | ||
| 171 | module subroutine initialise(this, input_shape, verbose) | ||
| 172 | !! Initialise the layer | ||
| 173 | class(base_layer_type), intent(inout) :: this | ||
| 174 | !! Instance of the layer | ||
| 175 | integer, dimension(:), intent(in) :: input_shape | ||
| 176 | !! Input shape | ||
| 177 | integer, optional, intent(in) :: verbose | ||
| 178 | !! Verbosity level | ||
| 179 | end subroutine initialise | ||
| 180 | end interface | ||
| 181 | |||
| 182 | interface | ||
| 183 | pure module function get_num_params(this) result(num_params) | ||
| 184 | !! Get number of parameters in layer | ||
| 185 | class(base_layer_type), intent(in) :: this | ||
| 186 | !! Instance of the layer | ||
| 187 | integer :: num_params | ||
| 188 | !! Number of parameters | ||
| 189 | end function get_num_params | ||
| 190 | end interface | ||
| 191 | |||
| 192 | interface | ||
| 193 | module subroutine forward_base(this, input) | ||
| 194 | !! Forward pass of layer | ||
| 195 | class(base_layer_type), intent(inout) :: this | ||
| 196 | !! Instance of the layer | ||
| 197 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 198 | !! Input data | ||
| 199 | end subroutine forward_base | ||
| 200 | |||
| 201 | module function forward_eval_base(this, input) result(output) | ||
| 202 | !! Forward pass of layer and return output for evaluation | ||
| 203 | class(base_layer_type), intent(inout), target :: this | ||
| 204 | !! Instance of the layer | ||
| 205 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 206 | !! Input data | ||
| 207 | type(array_type), pointer :: output(:,:) | ||
| 208 | !! Output data | ||
| 209 | end function forward_eval_base | ||
| 210 | |||
| 211 | module subroutine set_graph_base(this, graph) | ||
| 212 | !! Set the graph structure of the input data | ||
| 213 | class(base_layer_type), intent(inout) :: this | ||
| 214 | !! Instance of the layer | ||
| 215 | type(graph_type), dimension(:), intent(in) :: graph | ||
| 216 | !! Graph structure of input data | ||
| 217 | end subroutine set_graph_base | ||
| 218 | |||
| 219 | module subroutine nullify_graph_base(this) | ||
| 220 | !! Nullify the forward pass data of the layer to free memory | ||
| 221 | class(base_layer_type), intent(inout) :: this | ||
| 222 | !! Instance of the layer | ||
| 223 | end subroutine nullify_graph_base | ||
| 224 | end interface | ||
| 225 | |||
| 226 | interface | ||
| 227 | module subroutine read_layer(this, unit, verbose) | ||
| 228 | !! Read layer from file | ||
| 229 | class(base_layer_type), intent(inout) :: this | ||
| 230 | !! Instance of the layer | ||
| 231 | integer, intent(in) :: unit | ||
| 232 | !! File unit | ||
| 233 | integer, optional, intent(in) :: verbose | ||
| 234 | !! Verbosity level | ||
| 235 | end subroutine read_layer | ||
| 236 | |||
| 237 | module subroutine build_from_onnx_base( & | ||
| 238 | this, node, initialisers, value_info, verbose & | ||
| 239 | ) | ||
| 240 | !! Build layer from ONNX node | ||
| 241 | class(base_layer_type), intent(inout) :: this | ||
| 242 | !! Instance of the layer | ||
| 243 | type(onnx_node_type), intent(in) :: node | ||
| 244 | !! ONNX node | ||
| 245 | type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers | ||
| 246 | !! ONNX initialisers | ||
| 247 | type(onnx_tensor_type), dimension(:), intent(in) :: value_info | ||
| 248 | !! ONNX value info | ||
| 249 | integer, intent(in) :: verbose | ||
| 250 | !! Verbosity level | ||
| 251 | end subroutine build_from_onnx_base | ||
| 252 | |||
| 253 | module subroutine emit_onnx_nodes_base( & | ||
| 254 | this, prefix, & | ||
| 255 | nodes, num_nodes, max_nodes, & | ||
| 256 | inits, num_inits, max_inits, & | ||
| 257 | input_name, is_last_layer, format & | ||
| 258 | ) | ||
| 259 | !! Emit ONNX JSON nodes for this layer | ||
| 260 | !! Default implementation does nothing; override for GNN/NOP layers | ||
| 261 | class(base_layer_type), intent(in) :: this | ||
| 262 | !! Instance of the layer | ||
| 263 | character(*), intent(in) :: prefix | ||
| 264 | !! Node name prefix (e.g. "node_2") | ||
| 265 | type(onnx_node_type), intent(inout), dimension(:) :: nodes | ||
| 266 | !! Accumulator for ONNX nodes | ||
| 267 | integer, intent(inout) :: num_nodes | ||
| 268 | !! Current number of nodes | ||
| 269 | integer, intent(in) :: max_nodes | ||
| 270 | !! Maximum capacity | ||
| 271 | type(onnx_initialiser_type), intent(inout), dimension(:) :: inits | ||
| 272 | !! Accumulator for ONNX initialisers | ||
| 273 | integer, intent(inout) :: num_inits | ||
| 274 | !! Current number of initialisers | ||
| 275 | integer, intent(in) :: max_inits | ||
| 276 | !! Maximum capacity | ||
| 277 | character(*), optional, intent(in) :: input_name | ||
| 278 | !! Upstream tensor name used by sequential expanded ONNX format | ||
| 279 | logical, optional, intent(in) :: is_last_layer | ||
| 280 | !! Whether this is the last non-input layer in the network | ||
| 281 | integer, optional, intent(in) :: format | ||
| 282 | !! Export format selector | ||
| 283 | !! 1 = ONNX athena abstract format (default) | ||
| 284 | !! 2 = ONNX expanded format | ||
| 285 | end subroutine emit_onnx_nodes_base | ||
| 286 | |||
| 287 | module subroutine emit_onnx_graph_inputs_base( & | ||
| 288 | this, prefix, & | ||
| 289 | graph_inputs, num_inputs & | ||
| 290 | ) | ||
| 291 | !! Emit graph input tensor declarations for this layer | ||
| 292 | !! Default implementation does nothing; override for GNN layers | ||
| 293 | class(base_layer_type), intent(in) :: this | ||
| 294 | !! Instance of the layer | ||
| 295 | character(*), intent(in) :: prefix | ||
| 296 | !! Input name prefix (e.g. "input_1") | ||
| 297 | type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs | ||
| 298 | !! Accumulator for graph inputs | ||
| 299 | integer, intent(inout) :: num_inputs | ||
| 300 | !! Current number of inputs | ||
| 301 | end subroutine emit_onnx_graph_inputs_base | ||
| 302 | end interface | ||
| 303 | |||
| 304 | |||
| 305 | type, abstract, extends(base_layer_type) :: pad_layer_type | ||
| 306 | !! Type for padding layers | ||
| 307 | integer :: num_channels | ||
| 308 | !! Number of channels | ||
| 309 | integer :: imethod = 0 | ||
| 310 | !! Method for padding | ||
| 311 | integer, allocatable, dimension(:) :: pad | ||
| 312 | !! Padding size | ||
| 313 | character(len=20) :: method = 'valid' | ||
| 314 | !! Padding method | ||
| 315 | integer, allocatable, dimension(:,:) :: orig_bound, dest_bound | ||
| 316 | !! Original and destination bounds | ||
| 317 | type(facets_type), dimension(:), allocatable :: facets | ||
| 318 | !! Facets of the layer | ||
| 319 | contains | ||
| 320 | procedure, pass(this) :: init => init_pad | ||
| 321 | !! Initialise the layer | ||
| 322 | procedure, pass(this) :: print_to_unit => print_to_unit_pad | ||
| 323 | !! Print layer to unit | ||
| 324 | end type pad_layer_type | ||
| 325 | |||
| 326 | interface | ||
| 327 | module subroutine print_to_unit_pad(this, unit) | ||
| 328 | !! Print layer to unit | ||
| 329 | class(pad_layer_type), intent(in) :: this | ||
| 330 | !! Instance of the layer | ||
| 331 | integer, intent(in) :: unit | ||
| 332 | !! File unit | ||
| 333 | end subroutine print_to_unit_pad | ||
| 334 | |||
| 335 | module subroutine init_pad(this, input_shape, verbose) | ||
| 336 | class(pad_layer_type), intent(inout) :: this | ||
| 337 | integer, dimension(:), intent(in) :: input_shape | ||
| 338 | integer, optional, intent(in) :: verbose | ||
| 339 | end subroutine init_pad | ||
| 340 | end interface | ||
| 341 | |||
| 342 | |||
| 343 | type, abstract, extends(base_layer_type) :: pool_layer_type | ||
| 344 | !! Type for pooling layers | ||
| 345 | integer, allocatable, dimension(:) :: pool, strd | ||
| 346 | !! Pooling and stride sizes | ||
| 347 | integer :: num_channels | ||
| 348 | !! Number of channels | ||
| 349 | class(pad_layer_type), allocatable :: pad_layer | ||
| 350 | !! Padding layer | ||
| 351 | contains | ||
| 352 | procedure, pass(this) :: init => init_pool | ||
| 353 | !! Initialise the layer | ||
| 354 | procedure, pass(this) :: print_to_unit => print_to_unit_pool | ||
| 355 | !! Print layer to unit | ||
| 356 | procedure, pass(this) :: get_attributes => get_attributes_pool | ||
| 357 | !! Get the attributes of the layer (for ONNX export) | ||
| 358 | end type pool_layer_type | ||
| 359 | |||
| 360 | interface | ||
| 361 | module subroutine print_to_unit_pool(this, unit) | ||
| 362 | !! Print layer to unit | ||
| 363 | class(pool_layer_type), intent(in) :: this | ||
| 364 | !! Instance of the layer | ||
| 365 | integer, intent(in) :: unit | ||
| 366 | !! File unit | ||
| 367 | end subroutine print_to_unit_pool | ||
| 368 | |||
| 369 | module subroutine init_pool(this, input_shape, verbose) | ||
| 370 | class(pool_layer_type), intent(inout) :: this | ||
| 371 | integer, dimension(:), intent(in) :: input_shape | ||
| 372 | integer, optional, intent(in) :: verbose | ||
| 373 | end subroutine init_pool | ||
| 374 | |||
| 375 | module function get_attributes_pool(this) result(attributes) | ||
| 376 | !! Get the attributes of the layer (for ONNX export) | ||
| 377 | class(pool_layer_type), intent(in) :: this | ||
| 378 | !! Instance of the layer | ||
| 379 | type(onnx_attribute_type), allocatable, dimension(:) :: attributes | ||
| 380 | !! Attributes of the layer | ||
| 381 | end function get_attributes_pool | ||
| 382 | end interface | ||
| 383 | |||
| 384 | |||
| 385 | type, abstract, extends(base_layer_type) :: drop_layer_type | ||
| 386 | !! Type for dropout layers | ||
| 387 | real(real32) :: rate = 0.1_real32 | ||
| 388 | !! Dropout rate, rate = 1 - keep_prob -- typical = 0.05-0.25 | ||
| 389 | contains | ||
| 390 | procedure(generate_mask), deferred, pass(this) :: generate_mask | ||
| 391 | !! Generate dropout mask | ||
| 392 | end type drop_layer_type | ||
| 393 | |||
| 394 | abstract interface | ||
| 395 | subroutine generate_mask(this) | ||
| 396 | !! Generate dropout mask | ||
| 397 | import :: drop_layer_type | ||
| 398 | class(drop_layer_type), intent(inout) :: this | ||
| 399 | !! Instance of the layer | ||
| 400 | end subroutine generate_mask | ||
| 401 | end interface | ||
| 402 | |||
| 403 | |||
| 404 | type, abstract, extends(base_layer_type) :: merge_layer_type | ||
| 405 | !! Type for merge layers (i.e. add, multiply, concatenate) | ||
| 406 | integer :: merge_mode = 1 | ||
| 407 | !! Integer code for fundamental merge method | ||
| 408 | !! 1 = pointwise | ||
| 409 | !! 2 = concatenate | ||
| 410 | !! 3 = reduction | ||
| 411 | !! 4 = parametric (NOT IMPLEMENTED) | ||
| 412 | character(len=20) :: method | ||
| 413 | !! Merge method | ||
| 414 | integer :: num_input_layers = 0 | ||
| 415 | !! Number of input layers | ||
| 416 | integer, allocatable, dimension(:) :: input_layer_ids | ||
| 417 | !! IDs of input layers | ||
| 418 | contains | ||
| 419 | procedure(combine_merge), deferred, pass(this) :: combine | ||
| 420 | !! Merge two layers (forward) | ||
| 421 | procedure(calc_input_shape), deferred, pass(this) :: calc_input_shape | ||
| 422 | !! Calculate input shape based on shapes of input layers | ||
| 423 | end type merge_layer_type | ||
| 424 | |||
| 425 | interface | ||
| 426 | module subroutine combine_merge(this, input_list) | ||
| 427 | !! Combine two layers (forward) | ||
| 428 | class(merge_layer_type), intent(inout) :: this | ||
| 429 | !! Instance of the layer | ||
| 430 | type(array_ptr_type), dimension(:), intent(in) :: input_list | ||
| 431 | !! Input values | ||
| 432 | end subroutine combine_merge | ||
| 433 | |||
| 434 | module function calc_input_shape(this, input_shapes) result(input_shape) | ||
| 435 | !! Calculate input shape based on shapes of input layers | ||
| 436 | class(merge_layer_type), intent(in) :: this | ||
| 437 | !! Instance of the layer | ||
| 438 | integer, dimension(:,:), intent(in) :: input_shapes | ||
| 439 | !! Input shapes | ||
| 440 | integer, allocatable, dimension(:) :: input_shape | ||
| 441 | !! Calculated input shape | ||
| 442 | end function calc_input_shape | ||
| 443 | end interface | ||
| 444 | |||
| 445 | type, abstract, extends(base_layer_type) :: learnable_layer_type | ||
| 446 | !! Type for layers with learnable parameters | ||
| 447 | integer :: num_params = 0 | ||
| 448 | !! Number of learnable parameters | ||
| 449 | logical :: use_bias = .false. | ||
| 450 | !! Layer has bias | ||
| 451 | integer, allocatable, dimension(:,:) :: weight_shape | ||
| 452 | !! Shape of weights | ||
| 453 | integer, allocatable, dimension(:) :: bias_shape | ||
| 454 | !! Shape of biases | ||
| 455 | type(array_type), allocatable, dimension(:) :: params | ||
| 456 | !! Learnable parameters | ||
| 457 | character(len=14) :: kernel_initialiser='', bias_initialiser='' | ||
| 458 | !! Initialisers for kernel and bias | ||
| 459 | class(base_init_type), allocatable :: kernel_init, bias_init | ||
| 460 | !! Initialisers for kernel and bias | ||
| 461 | class(base_actv_type), allocatable :: activation | ||
| 462 | !! Activation function | ||
| 463 | contains | ||
| 464 | procedure, pass(this) :: get_params => get_params | ||
| 465 | !! Get learnable parameters of layer | ||
| 466 | procedure, pass(this) :: set_params => set_params | ||
| 467 | !! Set learnable parameters of layer | ||
| 468 | procedure, pass(this) :: get_gradients => get_gradients | ||
| 469 | !! Get parameter gradients of layer | ||
| 470 | procedure, pass(this) :: set_gradients => set_gradients | ||
| 471 | !! Set learnable parameters of layer | ||
| 472 | |||
| 473 | procedure, pass(this) :: reduce => reduce_learnable | ||
| 474 | !! Merge another learnable layer into this one | ||
| 475 | procedure :: add_t_t => add_learnable | ||
| 476 | !! Add two layers | ||
| 477 | generic :: operator(+) => add_t_t | ||
| 478 | !! Operator overloading for addition | ||
| 479 | end type learnable_layer_type | ||
| 480 | |||
| 481 | interface | ||
| 482 | module subroutine reduce_learnable(this, input) | ||
| 483 | !! Merge another learnable layer into this one | ||
| 484 | class(learnable_layer_type), intent(inout) :: this | ||
| 485 | !! Instance of the layer | ||
| 486 | class(learnable_layer_type), intent(in) :: input | ||
| 487 | !! Other layer to merge | ||
| 488 | end subroutine reduce_learnable | ||
| 489 | |||
| 490 | module function add_learnable(a, b) result(output) | ||
| 491 | !! Add two layers | ||
| 492 | class(learnable_layer_type), intent(in) :: a, b | ||
| 493 | !! Instances of the layers | ||
| 494 | class(learnable_layer_type), allocatable :: output | ||
| 495 | !! Output layer | ||
| 496 | end function add_learnable | ||
| 497 | end interface | ||
| 498 | |||
| 499 | interface | ||
| 500 | pure module function get_params(this) result(params) | ||
| 501 | !! Get learnable parameters of layer | ||
| 502 | class(learnable_layer_type), intent(in) :: this | ||
| 503 | !! Instance of the layer | ||
| 504 | real(real32), dimension(this%num_params) :: params | ||
| 505 | !! Learnable parameters | ||
| 506 | end function get_params | ||
| 507 | |||
| 508 | module subroutine set_params(this, params) | ||
| 509 | !! Set learnable parameters of layer | ||
| 510 | class(learnable_layer_type), intent(inout) :: this | ||
| 511 | !! Instance of the layer | ||
| 512 | real(real32), dimension(this%num_params), intent(in) :: params | ||
| 513 | !! Learnable parameters | ||
| 514 | end subroutine set_params | ||
| 515 | |||
| 516 | pure module function get_gradients(this, clip_method) result(gradients) | ||
| 517 | !! Get parameter gradients of layer | ||
| 518 | class(learnable_layer_type), intent(in) :: this | ||
| 519 | !! Instance of the layer | ||
| 520 | type(clip_type), optional, intent(in) :: clip_method | ||
| 521 | !! Clip method | ||
| 522 | real(real32), dimension(this%num_params) :: gradients | ||
| 523 | !! Parameter gradients | ||
| 524 | end function get_gradients | ||
| 525 | |||
| 526 | module subroutine set_gradients(this, gradients) | ||
| 527 | !! Set learnable parameters of layer | ||
| 528 | class(learnable_layer_type), intent(inout) :: this | ||
| 529 | !! Instance of the layer | ||
| 530 | real(real32), dimension(..), intent(in) :: gradients | ||
| 531 | !! Learnable parameters | ||
| 532 | end subroutine set_gradients | ||
| 533 | end interface | ||
| 534 | |||
| 535 | type, abstract, extends(learnable_layer_type) :: conv_layer_type | ||
| 536 | integer :: num_channels | ||
| 537 | !! Number of channels | ||
| 538 | integer :: num_filters | ||
| 539 | !! Number of filters | ||
| 540 | integer, allocatable, dimension(:) :: knl, stp, dil | ||
| 541 | !! Kernel, stride, and dilation sizes | ||
| 542 | real(real32), pointer :: bias(:) => null() | ||
| 543 | !! Bias pointer | ||
| 544 | class(pad_layer_type), allocatable :: pad_layer | ||
| 545 | !! Optional preprocess padding layer | ||
| 546 | class(array_type), allocatable :: di_padded | ||
| 547 | !! Padded input gradients | ||
| 548 | type(array_type), dimension(2) :: z | ||
| 549 | !! Temporary arrays for forward propagation | ||
| 550 | contains | ||
| 551 | procedure, pass(this) :: get_num_params => get_num_params_conv | ||
| 552 | !! Get the number of parameters in the layer | ||
| 553 | procedure, pass(this) :: init => init_conv | ||
| 554 | !! Initialise the layer | ||
| 555 | procedure, pass(this) :: get_attributes => get_attributes_conv | ||
| 556 | !! Get the attributes of the layer (for ONNX export) | ||
| 557 | procedure, pass(this) :: print_to_unit => print_to_unit_conv | ||
| 558 | !! Print layer to unit | ||
| 559 | end type conv_layer_type | ||
| 560 | |||
| 561 | interface | ||
| 562 | pure module function get_num_params_conv(this) result(num_params) | ||
| 563 | class(conv_layer_type), intent(in) :: this | ||
| 564 | integer :: num_params | ||
| 565 | end function get_num_params_conv | ||
| 566 | |||
| 567 | module subroutine init_conv(this, input_shape, verbose) | ||
| 568 | class(conv_layer_type), intent(inout) :: this | ||
| 569 | integer, dimension(:), intent(in) :: input_shape | ||
| 570 | integer, optional, intent(in) :: verbose | ||
| 571 | end subroutine init_conv | ||
| 572 | |||
| 573 | module function get_attributes_conv(this) result(attributes) | ||
| 574 | class(conv_layer_type), intent(in) :: this | ||
| 575 | type(onnx_attribute_type), allocatable, dimension(:) :: attributes | ||
| 576 | end function get_attributes_conv | ||
| 577 | |||
| 578 | module subroutine print_to_unit_conv(this, unit) | ||
| 579 | !! Print layer to unit | ||
| 580 | class(conv_layer_type), intent(in) :: this | ||
| 581 | !! Instance of the layer | ||
| 582 | integer, intent(in) :: unit | ||
| 583 | !! File unit | ||
| 584 | end subroutine print_to_unit_conv | ||
| 585 | end interface | ||
| 586 | |||
| 587 | type, abstract, extends(learnable_layer_type) :: batch_layer_type | ||
| 588 | !! Type for batch normalisation layers | ||
| 589 | integer :: num_channels | ||
| 590 | !! Number of channels | ||
| 591 | real(real32) :: momentum = 0.99_real32 | ||
| 592 | !! Momentum factor | ||
| 593 | !! NOTE: if momentum = 0, mean and variance batch-dependent values | ||
| 594 | !! NOTE: if momentum > 0, mean and variance are running averages | ||
| 595 | real(real32) :: epsilon = 0.001_real32 | ||
| 596 | !! Epsilon factor | ||
| 597 | real(real32) :: gamma_init_mean = 1._real32, gamma_init_std = 0.01_real32 | ||
| 598 | !! Initialisation parameters for gamma | ||
| 599 | real(real32) :: beta_init_mean = 0._real32, beta_init_std = 0.01_real32 | ||
| 600 | !! Initialisation parameters for beta | ||
| 601 | class(base_init_type), allocatable :: moving_mean_init, moving_variance_init | ||
| 602 | !! Initialisers for moving mean and variance | ||
| 603 | real(real32), allocatable, dimension(:) :: mean, variance | ||
| 604 | !! Mean and variance (not learnable) | ||
| 605 | contains | ||
| 606 | procedure, pass(this) :: get_num_params => get_num_params_batch | ||
| 607 | !! Get the number of parameters in the layer | ||
| 608 | procedure, pass(this) :: init => init_batch | ||
| 609 | !! Initialise the layer | ||
| 610 | procedure, pass(this) :: print_to_unit => print_to_unit_batch | ||
| 611 | !! Print layer to unit | ||
| 612 | procedure, pass(this) :: get_attributes => get_attributes_batch | ||
| 613 | !! Get the attributes of the layer (for ONNX export) | ||
| 614 | end type batch_layer_type | ||
| 615 | |||
| 616 | interface | ||
| 617 | |||
| 618 | pure module function get_num_params_batch(this) result(num_params) | ||
| 619 | class(batch_layer_type), intent(in) :: this | ||
| 620 | integer :: num_params | ||
| 621 | end function get_num_params_batch | ||
| 622 | |||
| 623 | module subroutine init_batch(this, input_shape, verbose) | ||
| 624 | class(batch_layer_type), intent(inout) :: this | ||
| 625 | integer, dimension(:), intent(in) :: input_shape | ||
| 626 | integer, optional, intent(in) :: verbose | ||
| 627 | end subroutine init_batch | ||
| 628 | |||
| 629 | module subroutine print_to_unit_batch(this, unit) | ||
| 630 | !! Print layer to unit | ||
| 631 | class(batch_layer_type), intent(in) :: this | ||
| 632 | !! Instance of the layer | ||
| 633 | integer, intent(in) :: unit | ||
| 634 | !! File unit | ||
| 635 | end subroutine print_to_unit_batch | ||
| 636 | |||
| 637 | module function get_attributes_batch(this) result(attributes) | ||
| 638 | class(batch_layer_type), intent(in) :: this | ||
| 639 | !! Instance of the layer | ||
| 640 | type(onnx_attribute_type), allocatable, dimension(:) :: attributes | ||
| 641 | !! Attributes of the layer | ||
| 642 | end function get_attributes_batch | ||
| 643 | end interface | ||
| 644 | |||
| 645 | |||
| 646 |
54/158✗ Branch 0 not taken.
✓ Branch 1 taken 531 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 531 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 531 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 531 times.
✓ Branch 45 taken 531 times.
✓ Branch 46 taken 531 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 531 times.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 52 not taken.
✗ Branch 53 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 64 not taken.
✗ Branch 65 not taken.
✓ Branch 67 taken 531 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 96 times.
✓ Branch 70 taken 435 times.
✓ Branch 71 taken 531 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 89 times.
✓ Branch 74 taken 442 times.
✓ Branch 75 taken 163 times.
✗ Branch 76 not taken.
✓ Branch 77 taken 89 times.
✓ Branch 78 taken 74 times.
✓ Branch 79 taken 163 times.
✗ Branch 80 not taken.
✓ Branch 81 taken 163 times.
✗ Branch 82 not taken.
✗ Branch 83 not taken.
✗ Branch 84 not taken.
✗ Branch 85 not taken.
✓ Branch 86 taken 163 times.
✗ Branch 87 not taken.
✓ Branch 88 taken 163 times.
✗ Branch 89 not taken.
✗ Branch 90 not taken.
✗ Branch 91 not taken.
✗ Branch 92 not taken.
✗ Branch 93 not taken.
✗ Branch 94 not taken.
✗ Branch 95 not taken.
✗ Branch 96 not taken.
✗ Branch 97 not taken.
✗ Branch 98 not taken.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✗ Branch 101 not taken.
✗ Branch 102 not taken.
✗ Branch 103 not taken.
✗ Branch 104 not taken.
✗ Branch 105 not taken.
✗ Branch 106 not taken.
✗ Branch 107 not taken.
✗ Branch 108 not taken.
✓ Branch 109 taken 368 times.
✗ Branch 110 not taken.
✓ Branch 111 taken 368 times.
✗ Branch 112 not taken.
✓ Branch 113 taken 320 times.
✓ Branch 114 taken 48 times.
✓ Branch 115 taken 48 times.
✗ Branch 116 not taken.
✓ Branch 118 taken 368 times.
✗ Branch 119 not taken.
✓ Branch 120 taken 320 times.
✓ Branch 121 taken 48 times.
✓ Branch 122 taken 368 times.
✗ Branch 123 not taken.
✓ Branch 124 taken 76 times.
✓ Branch 125 taken 292 times.
✓ Branch 126 taken 368 times.
✗ Branch 127 not taken.
✓ Branch 128 taken 76 times.
✓ Branch 129 taken 292 times.
✓ Branch 130 taken 368 times.
✗ Branch 131 not taken.
✓ Branch 132 taken 114 times.
✓ Branch 133 taken 254 times.
✓ Branch 134 taken 368 times.
✗ Branch 135 not taken.
✓ Branch 136 taken 76 times.
✓ Branch 137 taken 292 times.
✗ Branch 138 not taken.
✓ Branch 139 taken 368 times.
✗ Branch 140 not taken.
✓ Branch 141 taken 7 times.
✓ Branch 142 taken 361 times.
✓ Branch 143 taken 14 times.
✓ Branch 144 taken 347 times.
✗ Branch 145 not taken.
✓ Branch 146 taken 368 times.
✗ Branch 147 not taken.
✓ Branch 148 taken 7 times.
✓ Branch 149 taken 361 times.
✓ Branch 150 taken 14 times.
✓ Branch 151 taken 347 times.
✓ Branch 153 taken 368 times.
✗ Branch 154 not taken.
✓ Branch 155 taken 7 times.
✓ Branch 156 taken 361 times.
✗ Branch 157 not taken.
✓ Branch 158 taken 361 times.
✗ Branch 160 not taken.
✗ Branch 161 not taken.
✗ Branch 162 not taken.
✗ Branch 163 not taken.
✗ Branch 164 not taken.
✗ Branch 165 not taken.
✗ Branch 167 not taken.
✗ Branch 168 not taken.
✗ Branch 169 not taken.
✗ Branch 170 not taken.
✗ Branch 171 not taken.
✗ Branch 172 not taken.
✗ Branch 174 not taken.
✗ Branch 175 not taken.
✗ Branch 176 not taken.
✗ Branch 177 not taken.
✗ Branch 178 not taken.
✗ Branch 179 not taken.
✗ Branch 180 not taken.
✗ Branch 181 not taken.
|
2152 | end module athena__base_layer |
| 647 |