| Line | Branch | Exec | Source |
|---|---|---|---|
module athena__misc_types
  !! Module containing custom derived types and interfaces for ATHENA
  !!
  !! This module contains interfaces and derived types for
  !! activation functions, initialisers, facets, and ONNX export
  !! (attributes, nodes, initialisers, and tensors).
  !! The activation and initialiser types are abstract types that are used
  !! to define the activation functions and initialisers for the
  !! weights and biases in the neural network. Array operations are
  !! expressed in terms of `array_type`, which is imported from diffstruc.
  !! The facets type is used to store the faces, edges, and corners of
  !! the arrays for padding.
  use coreutils, only: real32
  use diffstruc, only: array_type
  implicit none


  private

  public :: base_actv_type
  public :: base_init_type
  public :: facets_type
  public :: onnx_attribute_type, onnx_node_type, onnx_initialiser_type, &
       onnx_tensor_type



!-------------------------------------------------------------------------------
! Attributes type (for ONNX export)
!-------------------------------------------------------------------------------
  type :: onnx_attribute_type
     !! Type for storing attributes for ONNX export
     character(64), allocatable :: name
     !! Name of the attribute
     !! NOTE(review): fixed length of 64; longer names are truncated on
     !! assignment. Unallocated until first assigned.
     character(10), allocatable :: type
     !! Type of the attribute (e.g. 'int', 'float', 'string')
     !! NOTE(review): fixed length of 10; longer type names are truncated
     !! on assignment. Unallocated until first assigned.
     character(len=:), allocatable :: val
     !! Value of the attribute as a string.
     !! This allows for flexible storage of different types
     !! of attributes without needing to define a specific type
  end type onnx_attribute_type

  interface onnx_attribute_type
     !! Generic constructor for onnx_attribute_type.
     !! The implementation is a separate module procedure (this module has
     !! no `contains` section, so it lives in a submodule).
     pure module function create_attribute(name, type, val) result(attribute)
       !! Function to create an ONNX attribute
       character(*), intent(in) :: name
       !! Name of the attribute
       character(*), intent(in) :: type
       !! Type of the attribute
       character(len=*), intent(in) :: val
       !! Value of the attribute as a string
       type(onnx_attribute_type) :: attribute
       !! Resulting ONNX attribute
     end function create_attribute
  end interface
!-------------------------------------------------------------------------------

!-------------------------------------------------------------------------------
! ONNX node type
!-------------------------------------------------------------------------------
  type :: onnx_node_type
     !! Type describing a single node of an ONNX graph
     character(256) :: op_type
     !! Operator type of the node (ONNX `op_type`)
     character(64) :: name
     !! Name of the node
     character(64), allocatable, dimension(:) :: inputs
     !! Names of the node inputs
     character(64), allocatable, dimension(:) :: outputs
     !! Names of the node outputs
     type(onnx_attribute_type), allocatable, dimension(:) :: attributes
     !! Attributes attached to the node
     integer :: num_inputs = 0, num_outputs = 0
     !! Number of inputs and outputs of the node.
     !! Default-initialised to zero so that a freshly created node (or an
     !! intent(out) dummy of this type) never carries undefined counts.
  end type onnx_node_type
!-------------------------------------------------------------------------------

!-------------------------------------------------------------------------------
! ONNX initialiser type
!-------------------------------------------------------------------------------
  type :: onnx_initialiser_type
     !! Type for storing an ONNX initialiser
     character(64) :: name
     !! Name of the initialiser
     integer, allocatable, dimension(:) :: dims
     !! Dimensions (shape) of the initialiser
     real(real32), allocatable, dimension(:) :: data
     !! Values of the initialiser stored as a rank-1 array
     !! (presumably flattened consistently with `dims` - confirm)
  end type onnx_initialiser_type
!-------------------------------------------------------------------------------

!-------------------------------------------------------------------------------
! ONNX tensor type
!-------------------------------------------------------------------------------
  type :: onnx_tensor_type
     !! Type for storing an ONNX tensor description
     character(64) :: name
     !! Name of the tensor
     integer :: elem_type
     !! Element type code of the tensor
     !! (presumably the ONNX TensorProto data-type enum - confirm)
     integer, allocatable, dimension(:) :: dims
     !! Dimensions of the tensor
  end type onnx_tensor_type
!-------------------------------------------------------------------------------



!------------------------------------------------------------------------------!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!------------------------------------------------------------------------------!



!-------------------------------------------------------------------------------
! Activation (aka transfer) function base type
!-------------------------------------------------------------------------------
  type, abstract :: base_actv_type
     !! Abstract type for activation functions
     character(10) :: name
     !! Name of the activation function
     real(real32) :: scale = 1._real32
     !! Scale of the activation function
     real(real32) :: threshold
     !! Threshold of the activation function
     !! NOTE(review): no default value here; presumably set by each
     !! extended type - confirm
     logical :: apply_scaling = .false.
     !! Boolean to apply scaling or not
   contains
     procedure(apply_actv), deferred, pass(this) :: apply
     !! Apply the activation function to an array
     procedure(reset_actv), deferred, pass(this) :: reset
     !! Reset activation function attributes and variables
     procedure(apply_attributes_actv), deferred, pass(this) :: apply_attributes
     !! Load ONNX attributes into the activation function
     procedure(export_attributes_actv), deferred, pass(this) :: export_attributes
     !! Export ONNX attributes
     procedure, pass(this) :: print_to_unit => print_to_unit_actv
     !! Print activation function details to a unit
  end type base_actv_type

  ! Interface for activation function
  !-----------------------------------------------------------------------------
  abstract interface
     subroutine reset_actv(this)
       !! Interface for resetting activation function attributes and variables
       import base_actv_type
       class(base_actv_type), intent(inout) :: this
       !! Instance of the activation type
     end subroutine reset_actv

     function apply_actv(this, val) result(output)
       !! Interface for applying the activation function
       import base_actv_type, array_type
       class(base_actv_type), intent(in) :: this
       !! Instance of the activation type
       type(array_type), intent(in) :: val
       !! Input array to which the activation is applied
       type(array_type), pointer :: output
       !! Pointer to the resulting array
     end function apply_actv

     subroutine apply_attributes_actv(this, attributes)
       !! Interface for loading ONNX attributes
       import base_actv_type, onnx_attribute_type
       class(base_actv_type), intent(inout) :: this
       !! Instance of the activation type
       type(onnx_attribute_type), dimension(:), intent(in) :: attributes
       !! ONNX attributes
     end subroutine apply_attributes_actv

     pure function export_attributes_actv(this) result(attributes)
       !! Interface for exporting ONNX attributes
       import base_actv_type, onnx_attribute_type
       class(base_actv_type), intent(in) :: this
       !! Instance of the activation type
       type(onnx_attribute_type), allocatable, dimension(:) :: attributes
       !! ONNX attributes describing the activation function
     end function export_attributes_actv
  end interface

  interface
     module subroutine print_to_unit_actv(this, unit, identifier)
       !! Interface for printing activation function details
       class(base_actv_type), intent(in) :: this
       !! Instance of the activation type
       integer, intent(in) :: unit
       !! Unit number for output
       character(len=*), intent(in), optional :: identifier
       !! Optional identifier for the activation function
     end subroutine print_to_unit_actv
  end interface
!-------------------------------------------------------------------------------


!------------------------------------------------------------------------------!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!------------------------------------------------------------------------------!


!-------------------------------------------------------------------------------
! Weights and biases initialiser base type
!-------------------------------------------------------------------------------
  type, abstract :: base_init_type
     !! Abstract type for initialising weights and biases
     character(len=20) :: name
     !! Name of the initialiser
     real(real32) :: scale = 1._real32, mean = 1._real32, std = 0.01_real32
     !! Scale, mean, and standard deviation of the initialiser
     !! NOTE(review): the default mean is 1.0 - confirm this is intended
   contains
     procedure(initialiser_subroutine), deferred, pass(this) :: initialise
     !! Abstract procedure for initialising weights and biases
  end type base_init_type

  ! Interface for initialiser function
  !-----------------------------------------------------------------------------
  abstract interface
     subroutine initialiser_subroutine(this, input, fan_in, fan_out, spacing)
       !! Interface for initialiser function
       import base_init_type, real32
       class(base_init_type), intent(inout) :: this
       !! Instance of the initialiser type
       real(real32), dimension(..), intent(out) :: input
       !! Array to initialise (assumed rank, so any rank is accepted)
       integer, optional, intent(in) :: fan_in, fan_out
       !! Number of input and output units
       integer, dimension(:), optional, intent(in) :: spacing
       !! Spacing of the array
     end subroutine initialiser_subroutine
  end interface
!-------------------------------------------------------------------------------


!------------------------------------------------------------------------------!
! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
!------------------------------------------------------------------------------!


!-------------------------------------------------------------------------------
! Facet type (for storing faces, edges, and corners for padding)
!-------------------------------------------------------------------------------
  type :: facets_type
     !! Type for storing faces, edges, and corners for padding
     integer :: num
     !! Number of facets
     integer :: rank
     !! Number of dimensions of the shape
     integer :: nfixed_dims
     !! Number of fixed dimensions
     character(6) :: type
     !! Type of facet, i.e. face, edge, corner
     integer, dimension(:), allocatable :: dim
     !! Dimension associated with each facet
     !! NOTE(review): presumably one entry per facet (size `num`) -
     !! confirm against the setup_bounds implementation
     integer, dimension(:,:,:), allocatable :: orig_bound
     !! Original bounds of the facet (2, nfixed_dims, num)
     integer, dimension(:,:,:), allocatable :: dest_bound
     !! Destination bounds of the facet (2, nfixed_dims, num)
   contains
     procedure, pass(this) :: setup_bounds
     !! Procedure for setting up bounds
  end type facets_type

  interface
     !! Interface for setting up bounds
     module subroutine setup_bounds(this, length, pad, imethod)
       !! Procedure for setting up bounds
       !!
       !! `length` and `pad` are explicit-shape arrays sized by `this%rank`,
       !! so `this%rank` must be set before this procedure is called.
       class(facets_type), intent(inout) :: this
       !! Instance of the facets type
       integer, dimension(this%rank), intent(in) :: length, pad
       !! Length of the shape and padding
       integer, intent(in) :: imethod
       !! Method for setting up bounds
     end subroutine setup_bounds
  end interface
!-------------------------------------------------------------------------------

end module athena__misc_types
| 259 |