| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__neural_operator_layer | ||
| 2 | !! Module containing implementation of a simple neural operator layer | ||
| 3 | !! | ||
| 4 | !! This module implements a neural operator layer that approximates a | ||
| 5 | !! discretized integral operator. The layer combines a standard affine | ||
| 6 | !! transform (local component) with a mean-field integral operator | ||
| 7 | !! (global/non-local component): | ||
| 8 | !! | ||
| 9 | !! \[ \mathbf{v} = \sigma\!\left(\mathbf{W}\mathbf{u} | ||
| 10 | !! + \mathbf{w}_k \langle\mathbf{u}\rangle + \mathbf{b}\right) \] | ||
| 11 | !! | ||
| 12 | !! where: | ||
| 13 | !! - \(\mathbf{u} \in \mathbb{R}^{n_{in}}\) is the input (discretised function) | ||
| 14 | !! - \(\mathbf{W} \in \mathbb{R}^{n_{out} \times n_{in}}\) are the local weights | ||
| 15 | !! - \(\mathbf{w}_k \in \mathbb{R}^{n_{out}}\) are the integral kernel weights | ||
| 16 | !! - \(\langle\mathbf{u}\rangle = \frac{1}{n_{in}}\sum_j u_j\) is the input mean | ||
| 17 | !! - \(\mathbf{b} \in \mathbb{R}^{n_{out}}\) is the bias | ||
| 18 | !! - \(\sigma\) is the activation function | ||
| 19 | !! | ||
| 20 | !! The global mean \(\langle\mathbf{u}\rangle\) acts as a rank-1 approximation | ||
| 21 | !! to a continuous integral operator \(\mathcal{K}[u](x) = \kappa(x)\int u\,\mathrm{d}y\), | ||
| 22 | !! where \(\mathbf{w}_k\) discretises \(\kappa\). Using this layer stacked | ||
| 23 | !! in sequence provides a resolution-invariant building block similar to the | ||
| 24 | !! graph neural operator family. | ||
| 25 | !! | ||
| 26 | !! Number of parameters: | ||
| 27 | !! - with bias: \(n_{out}(n_{in} + 1) + n_{out} = n_{out}(n_{in}+2)\) | ||
| 28 | !! - without bias: \(n_{out}(n_{in} + 1)\) | ||
| 29 | use coreutils, only: real32, stop_program | ||
| 30 | use athena__base_layer, only: learnable_layer_type, base_layer_type | ||
| 31 | use athena__misc_types, only: base_actv_type, base_init_type, & | ||
| 32 | onnx_attribute_type, & | ||
| 33 | onnx_node_type, onnx_initialiser_type, onnx_tensor_type | ||
| 34 | use athena__onnx_nop_utils, only: emit_nop_input_transpose, & | ||
| 35 | emit_nop_output_tail, emit_float_initialiser, emit_matrix_initialiser | ||
| 36 | use diffstruc, only: array_type, matmul, mean, operator(+) | ||
| 37 | use athena__initialiser_data, only: data_init_type | ||
| 38 | implicit none | ||
| 39 | |||
| 40 | |||
| 41 | private | ||
| 42 | |||
| 43 | public :: neural_operator_layer_type | ||
| 44 | public :: read_neural_operator_layer | ||
| 45 | |||
| 46 | |||
| 47 | type, extends(learnable_layer_type) :: neural_operator_layer_type | ||
| 48 | !! Type for a neural operator layer | ||
| 49 | integer :: num_inputs | ||
| 50 | !! Number of inputs (discretisation points of the input function) | ||
| 51 | integer :: num_outputs | ||
| 52 | !! Number of outputs (discretisation points of the output function) | ||
| 53 | type(array_type), dimension(1) :: z | ||
| 54 | !! Temporary array for pre-activation values (forward propagation) | ||
| 55 | contains | ||
| 56 | procedure, pass(this) :: get_num_params => get_num_params_neural_operator | ||
| 57 | !! Get the number of parameters for the neural operator layer | ||
| 58 | procedure, pass(this) :: set_hyperparams => set_hyperparams_neural_operator | ||
| 59 | !! Set the hyperparameters for the neural operator layer | ||
| 60 | procedure, pass(this) :: init => init_neural_operator | ||
| 61 | !! Initialise the neural operator layer | ||
| 62 | procedure, pass(this) :: print_to_unit => print_to_unit_neural_operator | ||
| 63 | !! Print the layer to a file | ||
| 64 | procedure, pass(this) :: read => read_neural_operator | ||
| 65 | !! Read the layer from a file | ||
| 66 | |||
| 67 | procedure, pass(this) :: forward => forward_neural_operator | ||
| 68 | !! Forward propagation | ||
| 69 | procedure, pass(this) :: get_attributes => get_attributes_neural_operator | ||
| 70 | !! Get layer attributes for ONNX export | ||
| 71 | procedure, pass(this) :: emit_onnx_nodes => & | ||
| 72 | emit_onnx_nodes_neural_operator | ||
| 73 | !! Emit format-aware ONNX nodes for the layer | ||
| 74 | |||
| 75 | final :: finalise_neural_operator | ||
| 76 | !! Finalise neural operator layer | ||
| 77 | end type neural_operator_layer_type | ||
| 78 | |||
| 79 | interface neural_operator_layer_type | ||
| 80 | !! Interface for setting up the neural operator layer | ||
| 81 | module function layer_setup( & | ||
| 82 | num_outputs, num_inputs, use_bias, & | ||
| 83 | activation, & | ||
| 84 | kernel_initialiser, bias_initialiser, verbose & | ||
| 85 | ) result(layer) | ||
| 86 | !! Setup a neural operator layer | ||
| 87 | integer, intent(in) :: num_outputs | ||
| 88 | !! Number of outputs | ||
| 89 | integer, optional, intent(in) :: num_inputs | ||
| 90 | !! Number of inputs | ||
| 91 | logical, optional, intent(in) :: use_bias | ||
| 92 | !! Whether to use bias | ||
| 93 | class(*), optional, intent(in) :: activation | ||
| 94 | !! Activation function | ||
| 95 | class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser | ||
| 96 | !! Kernel and bias initialisers | ||
| 97 | integer, optional, intent(in) :: verbose | ||
| 98 | !! Verbosity level | ||
| 99 | type(neural_operator_layer_type) :: layer | ||
| 100 | !! Instance of the neural operator layer | ||
| 101 | end function layer_setup | ||
| 102 | end interface neural_operator_layer_type | ||
| 103 | |||
| 104 | |||
| 105 | |||
| 106 | contains | ||
| 107 | |||
| 108 | !############################################################################### | ||
| 109 | 26 | subroutine finalise_neural_operator(this) | |
| 110 | !! Finalise neural operator layer | ||
| 111 | implicit none | ||
| 112 | |||
| 113 | ! Arguments | ||
| 114 | type(neural_operator_layer_type), intent(inout) :: this | ||
| 115 | !! Instance of the neural operator layer | ||
| 116 | |||
| 117 |
3/4✓ Branch 0 taken 24 times.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
|
26 | if(allocated(this%input_shape)) deallocate(this%input_shape) |
| 118 |
4/6✓ Branch 0 taken 24 times.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 24 times.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
|
26 | if(allocated(this%output)) deallocate(this%output) |
| 119 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 23 times.
|
26 | if(this%z(1)%allocated) call this%z(1)%deallocate() |
| 120 | |||
| 121 | 26 | end subroutine finalise_neural_operator | |
| 122 | !############################################################################### | ||
| 123 | |||
| 124 | |||
| 125 | !##############################################################################! | ||
| 126 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 127 | !##############################################################################! | ||
| 128 | |||
| 129 | |||
| 130 | !############################################################################### | ||
| 131 | 12 | pure function get_num_params_neural_operator(this) result(num_params) | |
| 132 | !! Get the number of parameters for the neural operator layer | ||
| 133 | !! | ||
| 134 | !! Parameters consist of: | ||
| 135 | !! - W matrix : num_outputs * num_inputs | ||
| 136 | !! - W_k vector : num_outputs (integral kernel coupling) | ||
| 137 | !! - b vector : num_outputs (if use_bias) | ||
| 138 | implicit none | ||
| 139 | |||
| 140 | ! Arguments | ||
| 141 | class(neural_operator_layer_type), intent(in) :: this | ||
| 142 | !! Instance of the neural operator layer | ||
| 143 | integer :: num_params | ||
| 144 | !! Number of parameters | ||
| 145 | |||
| 146 | ! W: n_out * n_in, W_k: n_out, b: n_out (if use_bias) | ||
| 147 | 12 | num_params = this%num_outputs * this%num_inputs + this%num_outputs | |
| 148 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 2 times.
|
12 | if(this%use_bias) num_params = num_params + this%num_outputs |
| 149 | |||
| 150 | 12 | end function get_num_params_neural_operator | |
| 151 | !############################################################################### | ||
| 152 | |||
| 153 | |||
| 154 | !##############################################################################! | ||
| 155 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 156 | !##############################################################################! | ||
| 157 | |||
| 158 | |||
| 159 | !############################################################################### | ||
| 160 | 12 | module function layer_setup( & | |
| 161 | num_outputs, num_inputs, & | ||
| 162 | use_bias, & | ||
| 163 | activation, & | ||
| 164 | kernel_initialiser, bias_initialiser, verbose & | ||
| 165 |
9/16✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 12 times.
|
24 | ) result(layer) |
| 166 | !! Setup a neural operator layer | ||
| 167 | use athena__activation, only: activation_setup | ||
| 168 | use athena__initialiser, only: initialiser_setup | ||
| 169 | implicit none | ||
| 170 | |||
| 171 | ! Arguments | ||
| 172 | integer, intent(in) :: num_outputs | ||
| 173 | !! Number of outputs | ||
| 174 | integer, optional, intent(in) :: num_inputs | ||
| 175 | !! Number of inputs | ||
| 176 | logical, optional, intent(in) :: use_bias | ||
| 177 | !! Whether to use bias | ||
| 178 | class(*), optional, intent(in) :: activation | ||
| 179 | !! Activation function | ||
| 180 | class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser | ||
| 181 | !! Kernel and bias initialisers | ||
| 182 | integer, optional, intent(in) :: verbose | ||
| 183 | !! Verbosity level | ||
| 184 | |||
| 185 | type(neural_operator_layer_type) :: layer | ||
| 186 | !! Instance of the neural operator layer | ||
| 187 | |||
| 188 | ! Local variables | ||
| 189 | integer :: verbose_ = 0 | ||
| 190 | !! Verbosity level | ||
| 191 | logical :: use_bias_ = .true. | ||
| 192 | !! Whether to use bias | ||
| 193 | 36 | class(base_actv_type), allocatable :: activation_ | |
| 194 | !! Activation function | ||
| 195 | 12 | class(base_init_type), allocatable :: kernel_initialiser_, bias_initialiser_ | |
| 196 | !! Kernel and bias initialisers | ||
| 197 | |||
| 198 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if(present(verbose)) verbose_ = verbose |
| 199 | |||
| 200 | |||
| 201 | !--------------------------------------------------------------------------- | ||
| 202 | ! Set use_bias | ||
| 203 | !--------------------------------------------------------------------------- | ||
| 204 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
|
12 | if(present(use_bias)) use_bias_ = use_bias |
| 205 | |||
| 206 | |||
| 207 | !--------------------------------------------------------------------------- | ||
| 208 | ! Set activation function | ||
| 209 | !--------------------------------------------------------------------------- | ||
| 210 |
3/4✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
|
12 | if(present(activation))then |
| 211 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 8 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 8 times.
✓ Branch 17 taken 8 times.
✗ Branch 18 not taken.
|
8 | activation_ = activation_setup(activation) |
| 212 | else | ||
| 213 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4 times.
✓ Branch 17 taken 4 times.
✗ Branch 18 not taken.
|
4 | activation_ = activation_setup("none") |
| 214 | end if | ||
| 215 | |||
| 216 | |||
| 217 | !--------------------------------------------------------------------------- | ||
| 218 | ! Define kernel and bias initialisers | ||
| 219 | !--------------------------------------------------------------------------- | ||
| 220 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
12 | if(present(kernel_initialiser))then |
| 221 | ✗ | kernel_initialiser_ = initialiser_setup(kernel_initialiser) | |
| 222 | end if | ||
| 223 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
12 | if(present(bias_initialiser))then |
| 224 | ✗ | bias_initialiser_ = initialiser_setup(bias_initialiser) | |
| 225 | end if | ||
| 226 | |||
| 227 | |||
| 228 | !--------------------------------------------------------------------------- | ||
| 229 | ! Set hyperparameters | ||
| 230 | !--------------------------------------------------------------------------- | ||
| 231 | call layer%set_hyperparams( & | ||
| 232 | num_outputs = num_outputs, & | ||
| 233 | use_bias = use_bias_, & | ||
| 234 | activation = activation_, & | ||
| 235 | kernel_initialiser = kernel_initialiser_, & | ||
| 236 | bias_initialiser = bias_initialiser_, & | ||
| 237 | verbose = verbose_ & | ||
| 238 | 12 | ) | |
| 239 | |||
| 240 | |||
| 241 | !--------------------------------------------------------------------------- | ||
| 242 | ! Initialise layer shape if num_inputs is provided | ||
| 243 | !--------------------------------------------------------------------------- | ||
| 244 |
4/4✓ Branch 0 taken 10 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 10 times.
✓ Branch 3 taken 10 times.
|
22 | if(present(num_inputs)) call layer%init(input_shape=[num_inputs]) |
| 245 | |||
| 246 |
13/28✓ Branch 0 taken 12 times.
✓ Branch 1 taken 12 times.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 12 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 12 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 26 taken 12 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 29 taken 12 times.
|
36 | end function layer_setup |
| 247 | !############################################################################### | ||
| 248 | |||
| 249 | |||
| 250 | !############################################################################### | ||
| 251 | 13 | subroutine set_hyperparams_neural_operator( & | |
| 252 | this, num_outputs, & | ||
| 253 | use_bias, & | ||
| 254 | activation, & | ||
| 255 | kernel_initialiser, bias_initialiser, & | ||
| 256 | verbose & | ||
| 257 | ) | ||
| 258 | !! Set the hyperparameters for the neural operator layer | ||
| 259 | use athena__activation, only: activation_setup | ||
| 260 | use athena__initialiser, only: get_default_initialiser, initialiser_setup | ||
| 261 | implicit none | ||
| 262 | |||
| 263 | ! Arguments | ||
| 264 | class(neural_operator_layer_type), intent(inout) :: this | ||
| 265 | !! Instance of the neural operator layer | ||
| 266 | integer, intent(in) :: num_outputs | ||
| 267 | !! Number of outputs | ||
| 268 | logical, intent(in) :: use_bias | ||
| 269 | !! Whether to use bias | ||
| 270 | class(base_actv_type), allocatable, intent(in) :: activation | ||
| 271 | !! Activation function | ||
| 272 | class(base_init_type), allocatable, intent(in) :: & | ||
| 273 | kernel_initialiser, bias_initialiser | ||
| 274 | !! Kernel and bias initialisers | ||
| 275 | integer, optional, intent(in) :: verbose | ||
| 276 | !! Verbosity level | ||
| 277 | |||
| 278 | ! Local variables | ||
| 279 | character(len=256) :: buffer | ||
| 280 | |||
| 281 | |||
| 282 |
5/8✓ Branch 0 taken 12 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 13 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
|
13 | this%name = "neural_operator" |
| 283 | 13 | this%type = "nop" | |
| 284 | 13 | this%input_rank = 1 | |
| 285 | 13 | this%output_rank = 1 | |
| 286 | 13 | this%use_bias = use_bias | |
| 287 | 13 | this%num_outputs = num_outputs | |
| 288 |
4/6✓ Branch 0 taken 1 times.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
|
13 | if(allocated(this%activation)) deallocate(this%activation) |
| 289 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
|
13 | if(.not.allocated(activation))then |
| 290 | ✗ | this%activation = activation_setup("none") | |
| 291 | else | ||
| 292 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
|
13 | allocate(this%activation, source=activation) |
| 293 | end if | ||
| 294 |
4/6✓ Branch 0 taken 1 times.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
|
13 | if(allocated(this%kernel_init)) deallocate(this%kernel_init) |
| 295 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 1 times.
|
13 | if(.not.allocated(kernel_initialiser))then |
| 296 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | buffer = get_default_initialiser(this%activation%name) |
| 297 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 12 times.
✓ Branch 17 taken 12 times.
✗ Branch 18 not taken.
|
12 | this%kernel_init = initialiser_setup(buffer) |
| 298 | else | ||
| 299 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | allocate(this%kernel_init, source=kernel_initialiser) |
| 300 | end if | ||
| 301 |
4/6✓ Branch 0 taken 1 times.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
|
13 | if(allocated(this%bias_init)) deallocate(this%bias_init) |
| 302 |
2/2✓ Branch 0 taken 12 times.
✓ Branch 1 taken 1 times.
|
13 | if(.not.allocated(bias_initialiser))then |
| 303 | buffer = get_default_initialiser( & | ||
| 304 | this%activation%name, & | ||
| 305 | is_bias=.true. & | ||
| 306 |
1/2✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
|
12 | ) |
| 307 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 12 times.
✓ Branch 17 taken 12 times.
✗ Branch 18 not taken.
|
12 | this%bias_init = initialiser_setup(buffer) |
| 308 | else | ||
| 309 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
|
1 | if(allocated(this%bias_init)) deallocate(this%bias_init) |
| 310 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | allocate(this%bias_init, source=bias_initialiser) |
| 311 | end if | ||
| 312 |
1/2✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
|
13 | if(present(verbose))then |
| 313 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
|
13 | if(abs(verbose).gt.0)then |
| 314 | write(*,'("NEURAL_OPERATOR activation function: ",A)') & | ||
| 315 | ✗ | trim(this%activation%name) | |
| 316 | write(*,'("NEURAL_OPERATOR kernel initialiser: ",A)') & | ||
| 317 | ✗ | trim(this%kernel_init%name) | |
| 318 | write(*,'("NEURAL_OPERATOR bias initialiser: ",A)') & | ||
| 319 | ✗ | trim(this%bias_init%name) | |
| 320 | end if | ||
| 321 | end if | ||
| 322 | |||
| 323 | 13 | end subroutine set_hyperparams_neural_operator | |
| 324 | !############################################################################### | ||
| 325 | |||
| 326 | |||
| 327 | !############################################################################### | ||
| 328 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | subroutine init_neural_operator(this, input_shape, verbose) |
| 329 | !! Initialise neural operator layer | ||
| 330 | implicit none | ||
| 331 | |||
| 332 | ! Arguments | ||
| 333 | class(neural_operator_layer_type), intent(inout) :: this | ||
| 334 | !! Instance of the neural operator layer | ||
| 335 | integer, dimension(:), intent(in) :: input_shape | ||
| 336 | !! Input shape | ||
| 337 | integer, optional, intent(in) :: verbose | ||
| 338 | !! Verbosity level | ||
| 339 | |||
| 340 | ! Local variables | ||
| 341 | integer :: num_inputs | ||
| 342 | !! Effective fan-in for initialisation | ||
| 343 | integer :: verbose_ = 0 | ||
| 344 | |||
| 345 | |||
| 346 | !--------------------------------------------------------------------------- | ||
| 347 | ! Initialise optional arguments | ||
| 348 | !--------------------------------------------------------------------------- | ||
| 349 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if(present(verbose)) verbose_ = verbose |
| 350 | |||
| 351 | |||
| 352 | !--------------------------------------------------------------------------- | ||
| 353 | ! Initialise number of inputs | ||
| 354 | !--------------------------------------------------------------------------- | ||
| 355 |
4/8✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
|
12 | if(.not.allocated(this%input_shape)) call this%set_shape(input_shape) |
| 356 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%num_inputs = this%input_shape(1) |
| 357 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12 times.
✓ Branch 7 taken 12 times.
|
24 | this%output_shape = [this%num_outputs] |
| 358 | 12 | this%num_params = this%get_num_params() | |
| 359 | |||
| 360 | |||
| 361 | !--------------------------------------------------------------------------- | ||
| 362 | ! Allocate parameters | ||
| 363 | ! | ||
| 364 | ! params(1): W (n_out x n_in) - local transform weights | ||
| 365 | ! params(2): W_k (n_out x 1) - integral kernel coupling weights | ||
| 366 | ! params(3): b (n_out) - bias [only when use_bias=.true.] | ||
| 367 | !--------------------------------------------------------------------------- | ||
| 368 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
|
12 | allocate(this%weight_shape(2,1)) |
| 369 |
9/16✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✓ Branch 21 taken 24 times.
✓ Branch 22 taken 12 times.
|
36 | this%weight_shape(:,1) = [ this%num_outputs, this%num_inputs ] |
| 370 | |||
| 371 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 2 times.
|
12 | if(this%use_bias)then |
| 372 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
✓ Branch 7 taken 10 times.
|
20 | this%bias_shape = [ this%num_outputs ] |
| 373 |
16/30✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✓ Branch 21 taken 30 times.
✓ Branch 22 taken 10 times.
✓ Branch 23 taken 30 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 30 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 30 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 30 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 30 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 30 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 30 times.
|
40 | allocate(this%params(3)) |
| 374 | else | ||
| 375 |
16/30✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 4 times.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 4 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 4 times.
|
6 | allocate(this%params(2)) |
| 376 | end if | ||
| 377 | |||
| 378 | ! W: local transform (n_out x n_in) | ||
| 379 |
14/24✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 12 times.
✓ Branch 27 taken 24 times.
✓ Branch 28 taken 12 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 12 times.
✓ Branch 31 taken 36 times.
✓ Branch 32 taken 12 times.
|
72 | call this%params(1)%allocate([this%weight_shape(:,1), 1]) |
| 380 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | call this%params(1)%set_requires_grad(.true.) |
| 381 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(1)%fix_pointer = .true. |
| 382 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(1)%is_sample_dependent = .false. |
| 383 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(1)%is_temporary = .false. |
| 384 | |||
| 385 | ! W_k: integral kernel coupling (n_out x 1) | ||
| 386 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✓ Branch 6 taken 36 times.
✓ Branch 7 taken 12 times.
|
48 | call this%params(2)%allocate([this%num_outputs, 1, 1]) |
| 387 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | call this%params(2)%set_requires_grad(.true.) |
| 388 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(2)%fix_pointer = .true. |
| 389 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(2)%is_sample_dependent = .false. |
| 390 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
|
12 | this%params(2)%is_temporary = .false. |
| 391 | |||
| 392 | 12 | num_inputs = this%num_inputs | |
| 393 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 2 times.
|
12 | if(this%use_bias)then |
| 394 | 10 | num_inputs = this%num_inputs + 1 | |
| 395 |
12/20✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✓ Branch 21 taken 10 times.
✓ Branch 22 taken 10 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 10 times.
✓ Branch 25 taken 20 times.
✓ Branch 26 taken 10 times.
|
40 | call this%params(3)%allocate([this%bias_shape, 1]) |
| 396 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
|
10 | call this%params(3)%set_requires_grad(.true.) |
| 397 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
|
10 | this%params(3)%fix_pointer = .true. |
| 398 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
|
10 | this%params(3)%is_sample_dependent = .false. |
| 399 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
|
10 | this%params(3)%is_temporary = .false. |
| 400 | end if | ||
| 401 | |||
| 402 | |||
| 403 | !--------------------------------------------------------------------------- | ||
| 404 | ! Initialise W with kernel initialiser | ||
| 405 | !--------------------------------------------------------------------------- | ||
| 406 | call this%kernel_init%initialise( & | ||
| 407 | 120 | this%params(1)%val(:,1), & | |
| 408 | fan_in = num_inputs, fan_out = this%num_outputs, & | ||
| 409 | spacing = [ this%num_outputs ] & | ||
| 410 |
12/22✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 12 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 12 times.
✓ Branch 30 taken 12 times.
✓ Branch 31 taken 12 times.
|
24 | ) |
| 411 | |||
| 412 | !--------------------------------------------------------------------------- | ||
| 413 | ! Initialise W_k with kernel initialiser (smaller scale), treating it as | ||
| 414 | ! a rank-1 integral correction so fan_in=1 | ||
| 415 | !--------------------------------------------------------------------------- | ||
| 416 | call this%kernel_init%initialise( & | ||
| 417 | 120 | this%params(2)%val(:,1), & | |
| 418 | fan_in = num_inputs, fan_out = this%num_outputs, & | ||
| 419 | spacing = [ this%num_outputs ] & | ||
| 420 |
12/22✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 12 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 12 times.
✓ Branch 30 taken 12 times.
✓ Branch 31 taken 12 times.
|
24 | ) |
| 421 | |||
| 422 | !--------------------------------------------------------------------------- | ||
| 423 | ! Initialise bias if used | ||
| 424 | !--------------------------------------------------------------------------- | ||
| 425 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 2 times.
|
12 | if(this%use_bias)then |
| 426 | call this%bias_init%initialise( & | ||
| 427 | 100 | this%params(3)%val(:,1), & | |
| 428 | fan_in = num_inputs, fan_out = this%num_outputs & | ||
| 429 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 10 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 10 times.
|
10 | ) |
| 430 | end if | ||
| 431 | |||
| 432 | |||
| 433 | !--------------------------------------------------------------------------- | ||
| 434 | ! Allocate output and pre-activation arrays | ||
| 435 | !--------------------------------------------------------------------------- | ||
| 436 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
|
12 | if(allocated(this%output)) deallocate(this%output) |
| 437 |
15/26✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 12 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 12 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 12 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 12 times.
✓ Branch 33 taken 12 times.
✓ Branch 34 taken 12 times.
✓ Branch 35 taken 12 times.
✓ Branch 36 taken 12 times.
|
36 | allocate(this%output(1,1)) |
| 438 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if(this%z(1)%allocated) call this%z(1)%deallocate() |
| 439 | |||
| 440 | 12 | end subroutine init_neural_operator | |
| 441 | !############################################################################### | ||
| 442 | |||
| 443 | |||
| 444 | !##############################################################################! | ||
| 445 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 446 | !##############################################################################! | ||
| 447 | |||
| 448 | |||
| 449 | !############################################################################### | ||
| 450 | 1 | subroutine print_to_unit_neural_operator(this, unit) | |
| 451 | !! Print neural operator layer to unit | ||
| 452 | use coreutils, only: to_upper | ||
| 453 | implicit none | ||
| 454 | |||
| 455 | ! Arguments | ||
| 456 | class(neural_operator_layer_type), intent(in) :: this | ||
| 457 | !! Instance of the neural operator layer | ||
| 458 | integer, intent(in) :: unit | ||
| 459 | !! File unit | ||
| 460 | |||
| 461 | |||
| 462 | ! Write hyperparameters | ||
| 463 | !--------------------------------------------------------------------------- | ||
| 464 | 1 | write(unit,'(3X,"NUM_INPUTS = ",I0)') this%num_inputs | |
| 465 | 1 | write(unit,'(3X,"NUM_OUTPUTS = ",I0)') this%num_outputs | |
| 466 | |||
| 467 | 1 | write(unit,'(3X,"USE_BIAS = ",L1)') this%use_bias | |
| 468 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(this%activation%name .ne. 'none')then |
| 469 | 1 | call this%activation%print_to_unit(unit) | |
| 470 | end if | ||
| 471 | |||
| 472 | |||
| 473 | ! Write weights, kernel coupling, and optional bias | ||
| 474 | !--------------------------------------------------------------------------- | ||
| 475 | 1 | write(unit,'("WEIGHTS")') | |
| 476 |
10/18✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 12 times.
✓ Branch 26 taken 1 times.
|
13 | write(unit,'(5(E16.8E2))') this%params(1)%val(:,1) ! W |
| 477 |
10/18✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 4 times.
✓ Branch 26 taken 1 times.
|
5 | write(unit,'(5(E16.8E2))') this%params(2)%val(:,1) ! W_k |
| 478 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(this%use_bias)then |
| 479 |
10/18✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 4 times.
✓ Branch 26 taken 1 times.
|
5 | write(unit,'(5(E16.8E2))') this%params(3)%val(:,1) ! b |
| 480 | end if | ||
| 481 | 1 | write(unit,'("END WEIGHTS")') | |
| 482 | |||
| 483 | 1 | end subroutine print_to_unit_neural_operator | |
| 484 | !############################################################################### | ||
| 485 | |||
| 486 | |||
| 487 | !############################################################################### | ||
| 488 | 1 | subroutine read_neural_operator(this, unit, verbose) | |
| 489 | !! Read neural operator layer from file | ||
| 490 | use athena__tools_infile, only: assign_val, assign_vec, move | ||
| 491 | use coreutils, only: to_lower, to_upper, icount | ||
| 492 | use athena__activation, only: read_activation | ||
| 493 | use athena__initialiser, only: initialiser_setup | ||
| 494 | implicit none | ||
| 495 | |||
| 496 | ! Arguments | ||
| 497 | class(neural_operator_layer_type), intent(inout) :: this | ||
| 498 | !! Instance of the neural operator layer | ||
| 499 | integer, intent(in) :: unit | ||
| 500 | !! Unit number | ||
| 501 | integer, optional, intent(in) :: verbose | ||
| 502 | !! Verbosity level | ||
| 503 | |||
| 504 | ! Local variables | ||
| 505 | integer :: stat | ||
| 506 | !! Status of read | ||
| 507 | integer :: verbose_ = 0 | ||
| 508 | !! Verbosity level | ||
| 509 | integer :: i, j, k, c, itmp1, iline, num_params | ||
| 510 | !! Loop variables and temporary integers | ||
| 511 | integer :: num_inputs, num_outputs | ||
| 512 | !! Number of inputs and outputs | ||
| 513 | logical :: use_bias = .true. | ||
| 514 | !! Whether to use bias | ||
| 515 | character(14) :: kernel_initialiser_name='', bias_initialiser_name='' | ||
| 516 | !! Initialiser names | ||
| 517 | character(20) :: activation_name='' | ||
| 518 | !! Activation function name | ||
| 519 | 3 | class(base_actv_type), allocatable :: activation | |
| 520 | !! Activation function | ||
| 521 | 5 | class(base_init_type), allocatable :: kernel_initialiser, bias_initialiser | |
| 522 | !! Initialisers | ||
| 523 | character(256) :: buffer, tag, err_msg | ||
| 524 | !! Buffer, tag, and error message | ||
| 525 | 1 | real(real32), allocatable, dimension(:) :: data_list | |
| 526 | !! Data list | ||
| 527 | integer :: param_line, final_line | ||
| 528 | !! Parameter line numbers | ||
| 529 | |||
| 530 | |||
| 531 | ! Initialise optional arguments | ||
| 532 | !--------------------------------------------------------------------------- | ||
| 533 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(verbose)) verbose_ = verbose |
| 534 | |||
| 535 | |||
| 536 | ! Loop over tags in layer card | ||
| 537 | !--------------------------------------------------------------------------- | ||
| 538 | 1 | iline = 0 | |
| 539 | 1 | param_line = 0 | |
| 540 | 1 | final_line = 0 | |
| 541 | 11 | tag_loop: do | |
| 542 | |||
| 543 | ! Check for end of file | ||
| 544 | !------------------------------------------------------------------------ | ||
| 545 | 12 | read(unit,'(A)',iostat=stat) buffer | |
| 546 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
12 | if(stat.ne.0)then |
| 547 | write(err_msg,'("file encountered error (EoF?) before END ",A)') & | ||
| 548 | ✗ | to_upper(this%name) | |
| 549 | ✗ | call stop_program(err_msg) | |
| 550 | ✗ | return | |
| 551 | end if | ||
| 552 |
2/4✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 12 times.
|
12 | if(trim(adjustl(buffer)).eq."") cycle tag_loop |
| 553 | |||
| 554 | ! Check for end of layer card | ||
| 555 | !------------------------------------------------------------------------ | ||
| 556 |
4/6✓ Branch 3 taken 12 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 11 times.
|
24 | if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then |
| 557 | 1 | final_line = iline | |
| 558 | 1 | backspace(unit) | |
| 559 | 12 | exit tag_loop | |
| 560 | end if | ||
| 561 | 11 | iline = iline + 1 | |
| 562 | |||
| 563 |
2/4✓ Branch 2 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✗ Branch 5 not taken.
|
11 | tag=trim(adjustl(buffer)) |
| 564 |
6/10✓ Branch 0 taken 3 times.
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✓ Branch 8 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
|
11 | if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1)) |
| 565 | |||
| 566 | ! Read parameters from file | ||
| 567 | !------------------------------------------------------------------------ | ||
| 568 | 22 | select case(trim(tag)) | |
| 569 | case("NUM_INPUTS") | ||
| 570 | 2 | call assign_val(buffer, num_inputs, itmp1) | |
| 571 | case("NUM_OUTPUTS") | ||
| 572 | 2 | call assign_val(buffer, num_outputs, itmp1) | |
| 573 | case("USE_BIAS") | ||
| 574 | 2 | call assign_val(buffer, use_bias, itmp1) | |
| 575 | case("ACTIVATION") | ||
| 576 | 1 | iline = iline - 1 | |
| 577 | 1 | backspace(unit) | |
| 578 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
|
1 | activation = read_activation(unit, iline) |
| 579 | case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER") | ||
| 580 | ✗ | call assign_val(buffer, kernel_initialiser_name, itmp1) | |
| 581 | case("BIAS_INITIALISER", "BIAS_INIT", "BIAS_INITIALIZER") | ||
| 582 | ✗ | call assign_val(buffer, bias_initialiser_name, itmp1) | |
| 583 | case("WEIGHTS") | ||
| 584 | 1 | kernel_initialiser_name = 'zeros' | |
| 585 | 1 | bias_initialiser_name = 'zeros' | |
| 586 | 1 | param_line = iline | |
| 587 | case default | ||
| 588 | ! Skip lines that only contain numbers (e.g. scientific notation) | ||
| 589 |
3/4✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✓ Branch 5 taken 1 times.
|
12 | if(scan(to_lower(trim(adjustl(buffer))),& |
| 590 | 'abcdfghijklmnopqrstuvwxyz').eq.0)then | ||
| 591 | 6 | cycle tag_loop | |
| 592 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | elseif(tag(:3).eq.'END')then |
| 593 | 6 | cycle tag_loop | |
| 594 | end if | ||
| 595 | write(err_msg,'("Unrecognised line in input file: ",A)') & | ||
| 596 | ✗ | trim(adjustl(buffer)) | |
| 597 | ✗ | call stop_program(err_msg) | |
| 598 |
7/10✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 6 times.
|
22 | return |
| 599 | end select | ||
| 600 | end do tag_loop | ||
| 601 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
|
1 | kernel_initialiser = initialiser_setup(kernel_initialiser_name) |
| 602 |
5/14✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
|
1 | bias_initialiser = initialiser_setup(bias_initialiser_name) |
| 603 | |||
| 604 | |||
| 605 | ! Set hyperparameters and initialise layer | ||
| 606 | !--------------------------------------------------------------------------- | ||
| 607 | call this%set_hyperparams( & | ||
| 608 | num_outputs = num_outputs, & | ||
| 609 | use_bias = use_bias, & | ||
| 610 | activation = activation, & | ||
| 611 | kernel_initialiser = kernel_initialiser, & | ||
| 612 | bias_initialiser = bias_initialiser, & | ||
| 613 | verbose = verbose_ & | ||
| 614 | 1 | ) | |
| 615 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | call this%init(input_shape=[num_inputs]) |
| 616 | |||
| 617 | |||
| 618 | ! Read weights if WEIGHTS card was found | ||
| 619 | !--------------------------------------------------------------------------- | ||
| 620 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(param_line.eq.0)then |
| 621 | ✗ | write(0,*) "WARNING: WEIGHTS card in " // trim(this%name) // " not found" | |
| 622 | else | ||
| 623 | 1 | call move(unit, param_line - iline, iostat=stat) | |
| 624 | |||
| 625 | ! Read W (num_inputs * num_outputs elements) | ||
| 626 | 1 | num_params = this%num_inputs * this%num_outputs | |
| 627 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | allocate(data_list(num_params), source=0._real32) |
| 628 | 1 | c = 1 | |
| 629 | 1 | k = 1 | |
| 630 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 3 times.
|
4 | data_concat_loop: do while(c.le.num_params) |
| 631 | 3 | read(unit,'(A)',iostat=stat) buffer | |
| 632 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
|
3 | if(stat.ne.0) exit data_concat_loop |
| 633 | 3 | k = icount(buffer) | |
| 634 |
5/8✗ Branch 1 not taken.
✓ Branch 2 taken 15 times.
✓ Branch 3 taken 12 times.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 12 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 12 times.
|
15 | read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1) |
| 635 | 3 | c = c + k | |
| 636 | end do data_concat_loop | ||
| 637 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 12 times.
✓ Branch 40 taken 1 times.
|
13 | this%params(1)%val(:,1) = data_list |
| 638 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | deallocate(data_list) |
| 639 | |||
| 640 | ! Read W_k (num_outputs elements) | ||
| 641 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | allocate(data_list(this%num_outputs), source=0._real32) |
| 642 | 1 | c = 1 | |
| 643 | 1 | k = 1 | |
| 644 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | data_concat_loop2: do while(c.le.this%num_outputs) |
| 645 | 1 | read(unit,'(A)',iostat=stat) buffer | |
| 646 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(stat.ne.0) exit data_concat_loop2 |
| 647 | 1 | k = icount(buffer) | |
| 648 |
5/8✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
|
5 | read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1) |
| 649 | 1 | c = c + k | |
| 650 | end do data_concat_loop2 | ||
| 651 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
5 | this%params(2)%val(:,1) = data_list(1:this%num_outputs) |
| 652 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | deallocate(data_list) |
| 653 | |||
| 654 | ! Read b (num_outputs elements, only if use_bias) | ||
| 655 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(use_bias)then |
| 656 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | allocate(data_list(num_outputs), source=0._real32) |
| 657 | 1 | c = 1 | |
| 658 | 1 | k = 1 | |
| 659 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | data_concat_loop3: do while(c.le.num_outputs) |
| 660 | 1 | read(unit,'(A)',iostat=stat) buffer | |
| 661 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(stat.ne.0) exit data_concat_loop3 |
| 662 | 1 | k = icount(buffer) | |
| 663 |
5/8✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
|
5 | read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1) |
| 664 | 1 | c = c + k | |
| 665 | end do data_concat_loop3 | ||
| 666 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
5 | this%params(3)%val(:,1) = data_list(1:num_outputs) |
| 667 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | deallocate(data_list) |
| 668 | end if | ||
| 669 | |||
| 670 | ! Check for END WEIGHTS tag | ||
| 671 | !------------------------------------------------------------------------ | ||
| 672 | 1 | read(unit,'(A)') buffer | |
| 673 |
2/4✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
|
1 | if(trim(adjustl(buffer)).ne."END WEIGHTS")then |
| 674 | ✗ | write(0,*) trim(adjustl(buffer)) | |
| 675 | ✗ | call stop_program("END WEIGHTS not where expected") | |
| 676 | ✗ | return | |
| 677 | end if | ||
| 678 | end if | ||
| 679 | |||
| 680 | |||
| 681 | !--------------------------------------------------------------------------- | ||
| 682 | ! Check for end of layer card | ||
| 683 | !--------------------------------------------------------------------------- | ||
| 684 | 1 | call move(unit, final_line - iline, iostat=stat) | |
| 685 | 1 | read(unit,'(A)') buffer | |
| 686 |
3/6✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
|
2 | if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then |
| 687 | ✗ | write(0,*) trim(adjustl(buffer)) | |
| 688 | ✗ | write(err_msg,'("END ",A," not where expected")') to_upper(this%name) | |
| 689 | ✗ | call stop_program(err_msg) | |
| 690 | 1 | return | |
| 691 | end if | ||
| 692 | |||
| 693 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | end subroutine read_neural_operator |
| 694 | !############################################################################### | ||
| 695 | |||
| 696 | |||
| 697 | !############################################################################### | ||
| 698 | 1 | function read_neural_operator_layer(unit, verbose) result(layer) | |
| 699 | !! Read neural operator layer from file and return as base_layer_type | ||
| 700 | implicit none | ||
| 701 | |||
| 702 | ! Arguments | ||
| 703 | integer, intent(in) :: unit | ||
| 704 | !! Unit number | ||
| 705 | integer, optional, intent(in) :: verbose | ||
| 706 | !! Verbosity level | ||
| 707 | class(base_layer_type), allocatable :: layer | ||
| 708 | !! Allocated layer instance | ||
| 709 | |||
| 710 | ! Local variables | ||
| 711 | integer :: verbose_ = 0 | ||
| 712 | |||
| 713 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(verbose)) verbose_ = verbose |
| 714 |
21/78✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✓ Branch 48 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 1 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✓ Branch 65 taken 1 times.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✓ Branch 68 taken 1 times.
✓ Branch 70 taken 1 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 1 times.
✗ Branch 73 not taken.
✗ Branch 74 not taken.
✓ Branch 75 taken 1 times.
✓ Branch 77 taken 1 times.
✗ Branch 78 not taken.
✓ Branch 79 taken 1 times.
✗ Branch 80 not taken.
✗ Branch 81 not taken.
✓ Branch 82 taken 1 times.
✓ Branch 84 taken 1 times.
✗ Branch 85 not taken.
|
2 | allocate(layer, source=neural_operator_layer_type(num_outputs=0)) |
| 715 | 1 | call layer%read(unit, verbose=verbose_) | |
| 716 | |||
| 717 | 2 | end function read_neural_operator_layer | |
| 718 | !############################################################################### | ||
| 719 | |||
| 720 | |||
| 721 | !##############################################################################! | ||
| 722 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 723 | !##############################################################################! | ||
| 724 | |||
| 725 | |||
| 726 | !############################################################################### | ||
| 727 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | subroutine forward_neural_operator(this, input) |
| 728 | !! Forward propagation for the neural operator layer | ||
| 729 | !! | ||
| 730 | !! Computes: | ||
| 731 | !! v = sigma( W * u + W_k * mean(u) + b ) | ||
| 732 | !! | ||
| 733 | !! where mean(u) is the global mean of the input (scalar per sample), | ||
| 734 | !! approximating the integral operator. | ||
| 735 | implicit none | ||
| 736 | |||
| 737 | ! Arguments | ||
| 738 | class(neural_operator_layer_type), intent(inout) :: this | ||
| 739 | !! Instance of the neural operator layer | ||
| 740 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 741 | !! Input values | ||
| 742 | |||
| 743 | ! Local variables | ||
| 744 | type(array_type), pointer :: ptr, ptr_mean, ptr_kern | ||
| 745 | |||
| 746 | |||
| 747 | ! Local transform: W · u → shape [n_out] | ||
| 748 | !--------------------------------------------------------------------------- | ||
| 749 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
|
4 | ptr => matmul(this%params(1), input(1,1)) |
| 750 | |||
| 751 | ! Integral (mean-field) term: W_k · mean(u) → shape [n_out] | ||
| 752 | ! mean(input, dim=1) reduces over all spatial elements, giving a scalar | ||
| 753 | ! per batch sample (shape [1]). matmul then expands W_k ([n_out x 1]) | ||
| 754 | ! by this scalar to produce a [n_out] correction vector. | ||
| 755 | !--------------------------------------------------------------------------- | ||
| 756 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | ptr_mean => mean(input(1,1), dim=1) |
| 757 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | ptr_kern => matmul(this%params(2), ptr_mean) |
| 758 | |||
| 759 | ! Combine local + integral terms | ||
| 760 | !--------------------------------------------------------------------------- | ||
| 761 | 4 | ptr => ptr + ptr_kern | |
| 762 | |||
| 763 | ! Add bias if used | ||
| 764 | !--------------------------------------------------------------------------- | ||
| 765 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if(this%use_bias)then |
| 766 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
|
4 | ptr => ptr + this%params(3) |
| 767 | end if | ||
| 768 | |||
| 769 | ! Apply activation function | ||
| 770 | !--------------------------------------------------------------------------- | ||
| 771 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
|
4 | call this%output(1,1)%zero_grad() |
| 772 |
3/4✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 3 times.
|
4 | if(trim(this%activation%name) .eq. "none")then |
| 773 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
|
1 | call this%output(1,1)%assign_and_deallocate_source(ptr) |
| 774 | else | ||
| 775 | 3 | call this%z(1)%zero_grad() | |
| 776 | 3 | call this%z(1)%assign_and_deallocate_source(ptr) | |
| 777 | 3 | this%z(1)%is_temporary = .false. | |
| 778 | 3 | ptr => this%activation%apply(this%z(1)) | |
| 779 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
|
3 | call this%output(1,1)%assign_and_deallocate_source(ptr) |
| 780 | end if | ||
| 781 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
|
4 | this%output(1,1)%is_temporary = .false. |
| 782 | |||
| 783 | 4 | end subroutine forward_neural_operator | |
| 784 | !############################################################################### | ||
| 785 | |||
| 786 | |||
| 787 | !############################################################################### | ||
| 788 | 2 | function get_attributes_neural_operator(this) result(attributes) | |
| 789 | !! Return list of neural operator attributes for ONNX export | ||
| 790 | implicit none | ||
| 791 | |||
| 792 | ! Arguments | ||
| 793 | class(neural_operator_layer_type), intent(in) :: this | ||
| 794 | !! Instance of the neural operator layer | ||
| 795 | type(onnx_attribute_type), allocatable, dimension(:) :: attributes | ||
| 796 | !! List of attributes for ONNX export | ||
| 797 | |||
| 798 | ! Local variables | ||
| 799 | character(32) :: buffer | ||
| 800 | !! Buffer for integer-to-string conversion | ||
| 801 | |||
| 802 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 8 times.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 8 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 8 times.
|
10 | allocate(attributes(4)) |
| 803 | |||
| 804 | 2 | write(buffer, '(I0)') this%num_inputs | |
| 805 | ✗ | attributes(1) = onnx_attribute_type( & | |
| 806 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
|
2 | name='num_inputs', type='int', val=trim(buffer)) |
| 807 | 2 | write(buffer, '(I0)') this%num_outputs | |
| 808 | ✗ | attributes(2) = onnx_attribute_type( & | |
| 809 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
|
2 | name='num_outputs', type='int', val=trim(buffer)) |
| 810 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
2 | if(this%use_bias)then |
| 811 | 2 | buffer = '1' | |
| 812 | else | ||
| 813 | ✗ | buffer = '0' | |
| 814 | end if | ||
| 815 | ✗ | attributes(3) = onnx_attribute_type( & | |
| 816 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
|
2 | name='use_bias', type='int', val=trim(buffer)) |
| 817 | ✗ | attributes(4) = onnx_attribute_type( & | |
| 818 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
|
2 | name='activation', type='string', val=trim(this%activation%name)) |
| 819 | |||
| 820 | 2 | end function get_attributes_neural_operator | |
| 821 | !############################################################################### | ||
| 822 | |||
| 823 | |||
| 824 | !############################################################################### | ||
| 825 | 1 | subroutine emit_onnx_nodes_neural_operator( & | |
| 826 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | this, prefix, nodes, num_nodes, max_nodes, inits, num_inits, & |
| 827 | max_inits, input_name, is_last_layer, format) | ||
| 828 | !! Emit decomposed standard ONNX nodes for a Neural Operator layer. | ||
| 829 | !! | ||
| 830 | !! Forward: v = sigma(W * u + w_k * mean(u) + b) | ||
| 831 | implicit none | ||
| 832 | |||
| 833 | ! Arguments | ||
| 834 | class(neural_operator_layer_type), intent(in) :: this | ||
| 835 | !! Neural operator layer instance | ||
| 836 | character(*), intent(in) :: prefix | ||
| 837 | !! Layer name prefix | ||
| 838 | type(onnx_node_type), intent(inout), dimension(:) :: nodes | ||
| 839 | !! Node accumulator | ||
| 840 | integer, intent(inout) :: num_nodes | ||
| 841 | !! Node counter | ||
| 842 | integer, intent(in) :: max_nodes | ||
| 843 | !! Node limit | ||
| 844 | type(onnx_initialiser_type), intent(inout), dimension(:) :: inits | ||
| 845 | !! Initialiser accumulator | ||
| 846 | integer, intent(inout) :: num_inits | ||
| 847 | !! Initialiser counter | ||
| 848 | integer, intent(in) :: max_inits | ||
| 849 | !! Initialiser limit | ||
| 850 | character(*), optional, intent(in) :: input_name | ||
| 851 | !! Name of the input tensor | ||
| 852 | logical, optional, intent(in) :: is_last_layer | ||
| 853 | !! Whether this is the last layer | ||
| 854 | integer, optional, intent(in) :: format | ||
| 855 | !! Export format selector | ||
| 856 | |||
| 857 | ! Local variables | ||
| 858 | integer :: n | ||
| 859 | character(128) :: w_name, wk_name, b_name | ||
| 860 | character(128) :: trans_in_out, mm_w_out, reduce_out | ||
| 861 | character(128) :: mul_out, add_out, add_b_out, final_output, & | ||
| 862 | output_source | ||
| 863 | character(4096) :: reduce_attr | ||
| 864 | integer :: format_ | ||
| 865 | |||
| 866 | 1 | format_ = 1 | |
| 867 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(format)) format_ = format |
| 868 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(format_ .ne. 2) return |
| 869 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(.not.present(input_name)) return |
| 870 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(.not.present(is_last_layer)) return |
| 871 | |||
| 872 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(w_name, '(A,".W")') trim(prefix) |
| 873 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(wk_name, '(A,".w_k")') trim(prefix) |
| 874 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(b_name, '(A,".b")') trim(prefix) |
| 875 | |||
| 876 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(trans_in_out, '("/",A,"/Transpose_output_0")') trim(prefix) |
| 877 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(mm_w_out, '("/",A,"/MatMul_output_0")') trim(prefix) |
| 878 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(reduce_out, '("/",A,"/ReduceMean_output_0")') trim(prefix) |
| 879 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(mul_out, '("/",A,"/Mul_output_0")') trim(prefix) |
| 880 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(add_out, '("/",A,"/Add_output_0")') trim(prefix) |
| 881 |
1/2✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
|
1 | write(add_b_out, '("/",A,"/Add_1_output_0")') trim(prefix) |
| 882 | |||
| 883 | reduce_attr = ' "attribute": [{"name": "axes", "ints": ' // & | ||
| 884 | '["0"], "type": "INTS"}, {"name": "keepdims", "i": "1", ' // & | ||
| 885 | 1 | '"type": "INT"}]' | |
| 886 | |||
| 887 | ! Transpose(input) | ||
| 888 | ✗ | call emit_nop_input_transpose(trim(prefix), trim(input_name), nodes, & | |
| 889 |
6/12✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 15 not taken.
|
1 | num_nodes, trim(trans_in_out)) |
| 890 | |||
| 891 | ! MatMul(W, x_t) | ||
| 892 | 1 | num_nodes = num_nodes + 1 | |
| 893 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
|
1 | write(nodes(num_nodes)%name, '("/",A,"/MatMul")') trim(prefix) |
| 894 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%op_type = 'MatMul' |
| 895 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%inputs(2)) |
| 896 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(1) = trim(w_name) |
| 897 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(2) = trim(trans_in_out) |
| 898 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%outputs(1)) |
| 899 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%outputs(1) = trim(mm_w_out) |
| 900 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%attributes_json = '' |
| 901 | |||
| 902 | ! ReduceMean(x_t, axis=0) | ||
| 903 | 1 | num_nodes = num_nodes + 1 | |
| 904 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
|
1 | write(nodes(num_nodes)%name, '("/",A,"/ReduceMean")') trim(prefix) |
| 905 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%op_type = 'ReduceMean' |
| 906 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%inputs(1)) |
| 907 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(1) = trim(trans_in_out) |
| 908 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%outputs(1)) |
| 909 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%outputs(1) = trim(reduce_out) |
| 910 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%attributes_json = reduce_attr |
| 911 | |||
| 912 | ! Mul(w_k, mean) | ||
| 913 | 1 | num_nodes = num_nodes + 1 | |
| 914 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
|
1 | write(nodes(num_nodes)%name, '("/",A,"/Mul")') trim(prefix) |
| 915 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%op_type = 'Mul' |
| 916 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%inputs(2)) |
| 917 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(1) = trim(wk_name) |
| 918 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(2) = trim(reduce_out) |
| 919 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%outputs(1)) |
| 920 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%outputs(1) = trim(mul_out) |
| 921 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%attributes_json = '' |
| 922 | |||
| 923 | ! Add(local, kernel) | ||
| 924 | 1 | num_nodes = num_nodes + 1 | |
| 925 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
|
1 | write(nodes(num_nodes)%name, '("/",A,"/Add")') trim(prefix) |
| 926 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%op_type = 'Add' |
| 927 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%inputs(2)) |
| 928 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(1) = trim(mm_w_out) |
| 929 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(2) = trim(mul_out) |
| 930 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%outputs(1)) |
| 931 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%outputs(1) = trim(add_out) |
| 932 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%attributes_json = '' |
| 933 | |||
| 934 | ! Add(combined, bias) | ||
| 935 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(this%use_bias)then |
| 936 | 1 | num_nodes = num_nodes + 1 | |
| 937 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
|
1 | write(nodes(num_nodes)%name, '("/",A,"/Add_1")') trim(prefix) |
| 938 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%op_type = 'Add' |
| 939 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%inputs(2)) |
| 940 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(1) = trim(add_out) |
| 941 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%inputs(2) = trim(b_name) |
| 942 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
1 | allocate(nodes(num_nodes)%outputs(1)) |
| 943 |
6/12✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
|
1 | nodes(num_nodes)%outputs(1) = trim(add_b_out) |
| 944 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | nodes(num_nodes)%attributes_json = '' |
| 945 | end if | ||
| 946 | |||
| 947 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(this%use_bias)then |
| 948 | 1 | output_source = add_b_out | |
| 949 | else | ||
| 950 | ✗ | output_source = add_out | |
| 951 | end if | ||
| 952 | call emit_nop_output_tail(trim(prefix), trim(this%activation%name), & | ||
| 953 |
6/12✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
|
1 | is_last_layer, trim(output_source), nodes, num_nodes, final_output) |
| 954 | |||
| 955 | ! Initialisers | ||
| 956 | 1 | n = this%num_outputs * this%num_inputs | |
| 957 | 10 | call emit_matrix_initialiser(trim(w_name), this%params(1)%val(:,1), & | |
| 958 |
14/28✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1 times.
✓ Branch 41 taken 1 times.
✗ Branch 42 not taken.
|
1 | this%num_outputs, this%num_inputs, inits, num_inits) |
| 959 | |||
| 960 | ! w_k: mean-field kernel [n_out, 1] | ||
| 961 | 10 | call emit_float_initialiser(trim(wk_name), this%params(2)%val(:,1), & | |
| 962 |
16/30✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✓ Branch 31 taken 2 times.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 43 taken 1 times.
✗ Branch 44 not taken.
|
3 | [this%num_outputs, 1], inits, num_inits) |
| 963 | |||
| 964 | ! b: bias [n_out, 1] | ||
| 965 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(this%use_bias)then |
| 966 | 10 | call emit_float_initialiser(trim(b_name), this%params(3)%val(:,1), & | |
| 967 |
16/30✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✓ Branch 31 taken 2 times.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 43 taken 1 times.
✗ Branch 44 not taken.
|
3 | [this%num_outputs, 1], inits, num_inits) |
| 968 | end if | ||
| 969 | |||
| 970 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | end subroutine emit_onnx_nodes_neural_operator |
| 971 | !############################################################################### | ||
| 972 | |||
| 973 |
42/91✓ Branch 0 taken 24 times.
✓ Branch 1 taken 26 times.
✓ Branch 2 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 26 times.
✓ Branch 5 taken 24 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 22 times.
✓ Branch 37 taken 2 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 22 times.
✓ Branch 40 taken 26 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 22 times.
✓ Branch 43 taken 26 times.
✓ Branch 44 taken 22 times.
✓ Branch 45 taken 48 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 22 times.
✓ Branch 48 taken 28 times.
✓ Branch 49 taken 48 times.
✓ Branch 50 taken 2 times.
✓ Branch 51 taken 48 times.
✓ Branch 52 taken 2 times.
✓ Branch 53 taken 20 times.
✓ Branch 54 taken 4 times.
✓ Branch 55 taken 22 times.
✓ Branch 56 taken 2 times.
✓ Branch 57 taken 64 times.
✓ Branch 58 taken 22 times.
✓ Branch 59 taken 64 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 64 times.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✓ Branch 64 taken 64 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 64 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 64 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 64 times.
✓ Branch 71 taken 24 times.
✗ Branch 72 not taken.
✓ Branch 74 taken 24 times.
✗ Branch 75 not taken.
✓ Branch 77 taken 24 times.
✗ Branch 78 not taken.
✓ Branch 80 taken 24 times.
✓ Branch 81 taken 24 times.
✗ Branch 82 not taken.
✓ Branch 83 taken 24 times.
✗ Branch 84 not taken.
✓ Branch 85 taken 24 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 24 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 24 times.
✗ Branch 90 not taken.
✓ Branch 91 taken 24 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 24 times.
|
460 | end module athena__neural_operator_layer |
| 974 |