GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_batchnorm1d_layer.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 167 220 75.9%
Functions: 0 0 -%
Branches: 437 1081 40.4%

Line Branch Exec Source
1 module athena__batchnorm1d_layer
2 !! Module containing implementation of 0D and 1D batch normalisation layers
3 !!
4 !! This module implements batch normalisation for 3D convolutional layers,
5 !! normalizing activations across the batch dimension.
6 !!
7 !! Mathematical operation (training):
8 !! \[ \mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i \]
9 !! \[ \sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 \]
10 !! \[ \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \]
11 !! \[ y_i = \gamma \hat{x}_i + \beta \]
12 !!
13 !! where \(\gamma, \beta\) are learnable parameters, \(\epsilon\) is stability constant
14 !!
15 !! Inference: uses running statistics
16 !! \(\mu_{\text{running}}, \sigma^2_{\text{running}}\) from training
17 !!
18 !! Benefits: Reduces internal covariate shift, enables higher learning rates,
19 !! acts as regularisation, reduces dependence on initialisation
20 !! Reference: Ioffe & Szegedy (2015), ICML
21 use coreutils, only: real32, stop_program, print_warning
22 use athena__base_layer, only: batch_layer_type, base_layer_type
23 use athena__misc_types, only: base_init_type, &
24 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
25 use diffstruc, only: array_type
26 use athena__diffstruc_extd, only: batchnorm_array_type, &
27 batchnorm, batchnorm_inference
28 implicit none
29
30
31 private
32
33 public :: batchnorm1d_layer_type
34 public :: read_batchnorm1d_layer
35
36
37 type, extends(batch_layer_type) :: batchnorm1d_layer_type
38 !! Type for 0D or 1D batch normalisation layer with overloaded procedures
39 contains
40 procedure, pass(this) :: set_hyperparams => set_hyperparams_batchnorm1d
41 !! Set hyperparameters for 1D batch normalisation layer
42 procedure, pass(this) :: read => read_batchnorm1d
43 !! Read 1D batch normalisation layer from file
44
45 procedure, pass(this) :: forward => forward_batchnorm1d
46 !! Forward propagation derived type handler
47
48 final :: finalise_batchnorm1d
49 !! Finalise 1D batch normalisation layer
50 end type batchnorm1d_layer_type
51
52
53 interface batchnorm1d_layer_type
54 !! Interface for setting up the 1D batch normalisation layer
55 module function layer_setup( &
56 input_shape, &
57 num_channels, num_inputs, &
58 momentum, epsilon, &
59 gamma_init_mean, gamma_init_std, &
60 beta_init_mean, beta_init_std, &
61 gamma_initialiser, beta_initialiser, &
62 moving_mean_initialiser, moving_variance_initialiser, &
63 verbose &
64 ) result(layer)
65 !! Set up the 1D batch normalisation layer
66 integer, dimension(:), optional, intent(in) :: input_shape
67 !! Input shape
68 integer, optional, intent(in) :: num_channels, num_inputs
69 !! Number of channels and inputs
70 real(real32), optional, intent(in) :: momentum, epsilon
71 !! Momentum and epsilon
72 real(real32), optional, intent(in) :: gamma_init_mean, gamma_init_std
73 !! Gamma initialisation mean and standard deviation
74 real(real32), optional, intent(in) :: beta_init_mean, beta_init_std
75 !! Beta initialisation mean and standard deviation
76 class(*), optional, intent(in) :: &
77 gamma_initialiser, beta_initialiser, &
78 moving_mean_initialiser, moving_variance_initialiser
79 !! Initialisers
80 integer, optional, intent(in) :: verbose
81 !! Verbosity level
82 type(batchnorm1d_layer_type) :: layer
83 !! Instance of the 1D batch normalisation layer
84 end function layer_setup
85 end interface batchnorm1d_layer_type
86
87
88
89 contains
90
91 !###############################################################################
92 6 subroutine finalise_batchnorm1d(this)
93 !! Finalise 1D batch normalisation layer
94 implicit none
95
96 ! Arguments
97 type(batchnorm1d_layer_type), intent(inout) :: this
98 !! Instance of the 1D batch normalisation layer
99
100
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
6 if(allocated(this%mean)) deallocate(this%mean)
101
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
6 if(allocated(this%variance)) deallocate(this%variance)
102
2/4
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
6 if(allocated(this%input_shape)) deallocate(this%input_shape)
103
3/6
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 6 times.
✓ Branch 5 taken 6 times.
✗ Branch 6 not taken.
6 if(allocated(this%output)) deallocate(this%output)
104
105 6 end subroutine finalise_batchnorm1d
106 !###############################################################################
107
108
109 !##############################################################################!
110 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
111 !##############################################################################!
112
113
114 !###############################################################################
115 7 module function layer_setup( &
116
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
7 input_shape, &
117 num_channels, num_inputs, &
118 momentum, epsilon, &
119 gamma_init_mean, gamma_init_std, &
120 beta_init_mean, beta_init_std, &
121 gamma_initialiser, beta_initialiser, &
122 moving_mean_initialiser, moving_variance_initialiser, &
123 verbose &
124 7 ) result(layer)
125 !! Set up the 1D batch normalisation layer
126 use athena__initialiser, only: initialiser_setup
127 implicit none
128
129 ! Arguments
130 integer, dimension(:), optional, intent(in) :: input_shape
131 !! Input shape
132 integer, optional, intent(in) :: num_channels, num_inputs
133 !! Number of channels and inputs
134 real(real32), optional, intent(in) :: momentum, epsilon
135 !! Momentum and epsilon
136 real(real32), optional, intent(in) :: gamma_init_mean, gamma_init_std
137 !! Gamma initialisation mean and standard deviation
138 real(real32), optional, intent(in) :: beta_init_mean, beta_init_std
139 !! Beta initialisation mean and standard deviation
140 class(*), optional, intent(in) :: &
141 gamma_initialiser, beta_initialiser, &
142 moving_mean_initialiser, moving_variance_initialiser
143 !! Initialisers
144 integer, optional, intent(in) :: verbose
145 !! Verbosity level
146
147 type(batchnorm1d_layer_type) :: layer
148 !! Instance of the 1D batch normalisation layer
149
150 ! Local variables
151 integer :: verbose_ = 0
152 !! Verbosity level
153 character(256) :: err_msg
154 !! Error message
155 class(base_init_type), allocatable :: &
156 23 gamma_initialiser_, beta_initialiser_, &
157 23 moving_mean_initialiser_, moving_variance_initialiser_
158 !! Initialisers
159
160
161
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
7 if(present(verbose)) verbose_ = verbose
162
163 !---------------------------------------------------------------------------
164 ! Set up momentum and epsilon
165 !---------------------------------------------------------------------------
166
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(momentum))then
167 1 layer%momentum = momentum
168 else
169 6 layer%momentum = 0._real32
170 end if
171
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(epsilon))then
172 1 layer%epsilon = epsilon
173 else
174 6 layer%epsilon = 1.E-5_real32
175 end if
176
177
178 !---------------------------------------------------------------------------
179 ! Set up initialiser mean and standard deviations
180 !---------------------------------------------------------------------------
181
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(gamma_init_mean)) layer%gamma_init_mean = gamma_init_mean
182
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(gamma_init_std)) layer%gamma_init_std = gamma_init_std
183
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(beta_init_mean)) layer%beta_init_mean = beta_init_mean
184
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
7 if(present(beta_init_std)) layer%beta_init_std = beta_init_std
185
186
187 !---------------------------------------------------------------------------
188 ! Define gamma and beta initialisers
189 !---------------------------------------------------------------------------
190
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
7 if(present(gamma_initialiser))then
191
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 gamma_initialiser_ = initialiser_setup(gamma_initialiser)
192 end if
193
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
7 if(present(beta_initialiser))then
194
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 beta_initialiser_ = initialiser_setup(beta_initialiser)
195 end if
196
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
7 if(present(moving_mean_initialiser))then
197
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 moving_mean_initialiser_ = initialiser_setup(moving_mean_initialiser)
198 end if
199
3/4
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
7 if(present(moving_variance_initialiser))then
200
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 moving_variance_initialiser_ = initialiser_setup(moving_variance_initialiser)
201 end if
202
203
204 !---------------------------------------------------------------------------
205 ! Set hyperparameters
206 !---------------------------------------------------------------------------
207 call layer%set_hyperparams( &
208 momentum = layer%momentum, epsilon = layer%epsilon, &
209 gamma_init_mean = layer%gamma_init_mean, &
210 gamma_init_std = layer%gamma_init_std, &
211 beta_init_mean = layer%beta_init_mean, &
212 beta_init_std = layer%beta_init_std, &
213 gamma_initialiser = gamma_initialiser_, &
214 beta_initialiser = beta_initialiser_, &
215 moving_mean_initialiser = moving_mean_initialiser_, &
216 moving_variance_initialiser = moving_variance_initialiser_, &
217 verbose = verbose_ &
218 7 )
219
220
221 !---------------------------------------------------------------------------
222 ! Initialise layer shape
223 !---------------------------------------------------------------------------
224
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 3 times.
7 if(present(input_shape))then
225
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if(present(num_channels).or.present(num_inputs))then
226 write(err_msg,'(A)') &
227 "both input_shape and num_channels/num_inputs present" // &
228 achar(13) // achar(10) // &
229 "These represent the same parameter, so are conflicting"
230 call stop_program(err_msg)
231 return
232 end if
233
5/10
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
4 if(size(input_shape).eq.1)then
234
10/16
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 4 times.
✓ Branch 14 taken 4 times.
✓ Branch 15 taken 4 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 4 times.
✓ Branch 18 taken 8 times.
✓ Branch 19 taken 4 times.
16 call layer%init(input_shape= [ 1, input_shape ] )
235 else
236 call layer%init(input_shape= input_shape)
237 end if
238
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
3 elseif(present(num_channels).and.present(num_inputs))then
239
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 2 times.
6 call layer%init(input_shape=[num_inputs, num_channels])
240
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 elseif(present(num_channels))then
241 call layer%init(input_shape=[1, num_channels])
242
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 elseif(present(num_inputs))then
243 call layer%init(input_shape=[num_inputs, 1])
244 end if
245
246
14/18
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 3 times.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 6 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✓ Branch 7 taken 1 times.
✓ Branch 8 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 1 times.
✓ Branch 13 taken 6 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✓ Branch 18 taken 6 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
14 end function layer_setup
247 !###############################################################################
248
249
250 !###############################################################################
251 8 subroutine set_hyperparams_batchnorm1d( &
252 this, &
253 momentum, epsilon, &
254 gamma_init_mean, gamma_init_std, &
255 beta_init_mean, beta_init_std, &
256 gamma_initialiser, beta_initialiser, &
257 moving_mean_initialiser, moving_variance_initialiser, &
258 verbose )
259 !! Set hyperparameters for 1D batch normalisation layer
260 use athena__initialiser, only: initialiser_setup
261 implicit none
262
263 ! Arguments
264 class(batchnorm1d_layer_type), intent(inout) :: this
265 !! Instance of the 1D batch normalisation layer
266 real(real32), intent(in) :: momentum, epsilon
267 !! Momentum and epsilon
268 real(real32), intent(in) :: gamma_init_mean, gamma_init_std
269 !! Gamma initialisation mean and standard deviation
270 real(real32), intent(in) :: beta_init_mean, beta_init_std
271 !! Beta initialisation mean and standard deviation
272 class(base_init_type), allocatable, intent(in) :: &
273 gamma_initialiser, beta_initialiser
274 !! Gamma and beta initialisers
275 class(base_init_type), allocatable, intent(in) :: &
276 moving_mean_initialiser, moving_variance_initialiser
277 !! Moving mean and variance initialisers
278 integer, optional, intent(in) :: verbose
279 !! Verbosity level
280
281
282
5/8
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
8 this%name = "batchnorm1d"
283 8 this%type = "batc"
284 8 this%input_rank = 2
285 8 this%output_rank = 2
286 8 this%use_bias = .true.
287 8 this%momentum = momentum
288 8 this%epsilon = epsilon
289
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
8 if(allocated(this%kernel_init)) deallocate(this%kernel_init)
290
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 if(.not.allocated(gamma_initialiser))then
291
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6 times.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
6 this%kernel_init = initialiser_setup('ones')
292 else
293
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 allocate(this%kernel_init, source=gamma_initialiser)
294 end if
295
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
8 if(allocated(this%bias_init)) deallocate(this%bias_init)
296
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 if(.not.allocated(beta_initialiser))then
297
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6 times.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
6 this%bias_init = initialiser_setup('zeros')
298 else
299
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 allocate(this%bias_init, source=beta_initialiser)
300 end if
301
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 if(.not.allocated(moving_mean_initialiser))then
302
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6 times.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
6 this%moving_mean_init = initialiser_setup('zeros')
303 else
304
7/10
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 5 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
2 this%moving_mean_init = moving_mean_initialiser
305 end if
306
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 if(.not.allocated(moving_variance_initialiser))then
307
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 6 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 6 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 6 times.
✓ Branch 17 taken 6 times.
✗ Branch 18 not taken.
6 this%moving_variance_init = initialiser_setup('ones')
308 else
309
7/10
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 5 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
2 this%moving_variance_init = moving_variance_initialiser
310 end if
311 8 this%gamma_init_mean = gamma_init_mean
312 8 this%gamma_init_std = gamma_init_std
313 8 this%beta_init_mean = beta_init_mean
314 8 this%beta_init_std = beta_init_std
315 8 this%kernel_init%mean = this%gamma_init_mean
316 8 this%kernel_init%std = this%gamma_init_std
317 8 this%bias_init%mean = this%beta_init_mean
318 8 this%bias_init%std = this%beta_init_std
319
1/2
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
8 if(present(verbose))then
320
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if(abs(verbose).gt.0)then
321 write(*,'("BATCHNORM1D gamma (kernel) initialiser: ",A)') &
322 trim(this%kernel_init%name)
323 write(*,'("BATCHNORM1D beta (bias) initialiser: ",A)') &
324 trim(this%bias_init%name)
325 write(*,'("BATCHNORM1D moving mean initialiser: ",A)') &
326 trim(this%moving_mean_init%name)
327 write(*,'("BATCHNORM1D moving variance initialiser: ",A)') &
328 trim(this%moving_variance_init%name)
329 end if
330 end if
331
332 8 end subroutine set_hyperparams_batchnorm1d
333 !###############################################################################
334
335
336 !##############################################################################!
337 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
338 !##############################################################################!
339
340
341 !###############################################################################
342 1 subroutine read_batchnorm1d(this, unit, verbose)
343 !! Read 1D batch normalisation layer from file
344 use athena__tools_infile, only: assign_val, assign_vec, move
345 use coreutils, only: to_lower, to_upper, icount
346 use athena__initialiser, only: initialiser_setup
347 implicit none
348
349 ! Arguments
350 class(batchnorm1d_layer_type), intent(inout) :: this
351 !! Instance of the 1D batch normalisation layer
352 integer, intent(in) :: unit
353 !! File unit
354 integer, optional, intent(in) :: verbose
355 !! Verbosity level
356
357 ! Local variables
358 integer :: stat, verbose_ = 0
359 !! File status and verbosity level
360 integer :: i, j, k, c, itmp1, iline, final_line
361 !! Temporary integers and loop indices
362 integer :: num_channels
363 !! Number of channels
364 real(real32) :: momentum = 0._real32, epsilon = 1.E-5_real32
365 !! Momentum and epsilon
366 5 class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
367 !! Initialisers
368 class(base_init_type), allocatable :: &
369 5 moving_mean_initialiser, moving_variance_initialiser
370 !! Moving mean and variance initialisers
371 character(14) :: gamma_initialiser_name='', beta_initialiser_name=''
372 !! Initialisers
373 character(14) :: &
374 moving_mean_initialiser_name='', &
375 moving_variance_initialiser_name=''
376 !! Moving mean and variance initialisers
377 character(256) :: buffer, tag, err_msg
378 !! Buffer, tag, and error message
379 integer, dimension(2) :: input_shape
380 !! Input shape
381 1 real(real32), allocatable, dimension(:) :: data_list
382 !! Data list
383 integer, dimension(2) :: param_lines
384 !! Lines where parameters are found
385
386
387 ! Initialise optional arguments
388 !---------------------------------------------------------------------------
389
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(verbose)) verbose_ = verbose
390
391
392 ! Loop over tags in layer card
393 !---------------------------------------------------------------------------
394 1 iline = 0
395 1 param_lines = 0
396 1 final_line = 0
397 16 tag_loop: do
398
399 ! Check for end of file
400 !------------------------------------------------------------------------
401 17 read(unit,'(A)',iostat=stat) buffer
402
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
17 if(stat.ne.0)then
403 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
404 to_upper(this%name)
405 call stop_program(err_msg)
406 return
407 end if
408
2/4
✓ Branch 2 taken 17 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 17 times.
17 if(trim(adjustl(buffer)).eq."") cycle tag_loop
409
410 ! Check for end of layer card
411 !------------------------------------------------------------------------
412
4/6
✓ Branch 3 taken 17 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 17 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 16 times.
34 if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
413 1 final_line = iline
414 1 backspace(unit)
415 17 exit tag_loop
416 end if
417 16 iline = iline + 1
418
419
2/4
✓ Branch 2 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 5 not taken.
16 tag = trim(adjustl(buffer))
420
6/10
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 8 times.
✓ Branch 8 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 11 not taken.
16 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
421
422 ! Read parameters from save file
423 !------------------------------------------------------------------------
424 32 select case(trim(tag))
425 case("INPUT_SHAPE")
426 1 call assign_vec(buffer, input_shape, itmp1)
427 case("MOMENTUM")
428 2 call assign_val(buffer, momentum, itmp1)
429 case("EPSILON")
430 2 call assign_val(buffer, epsilon, itmp1)
431 case("NUM_CHANNELS")
432 1 call assign_val(buffer, num_channels, itmp1)
433 1 write(0,*) "NUM_CHANNELS and INPUT_SHAPE are conflicting parameters"
434 1 write(0,*) "NUM_CHANNELS will be ignored"
435 case("GAMMA_INITIALISER", "KERNEL_INITIALISER")
436
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(param_lines(1).ne.0)then
437 write(err_msg,'("GAMMA and GAMMA_INITIALISER defined. Using GAMMA only.")')
438 call print_warning(err_msg)
439 end if
440 1 call assign_val(buffer, gamma_initialiser_name, itmp1)
441 case("BETA_INITIALISER", "BIAS_INITIALISER")
442
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(param_lines(2).ne.0)then
443 write(err_msg,'("BETA and BETA_INITIALISER defined. Using BETA only.")')
444 call print_warning(err_msg)
445 end if
446 1 call assign_val(buffer, beta_initialiser_name, itmp1)
447 case("MOVING_MEAN_INITIALISER")
448 1 call assign_val(buffer, moving_mean_initialiser_name, itmp1)
449 case("MOVING_VARIANCE_INITIALISER")
450 1 call assign_val(buffer, moving_variance_initialiser_name, itmp1)
451 case("GAMMA")
452 1 gamma_initialiser_name = 'zeros'
453 1 param_lines(1) = iline
454 case("BETA")
455 1 beta_initialiser_name = 'zeros'
456 1 param_lines(2) = iline
457 case default
458 ! Don't look for "e" due to scientific notation of numbers
459 ! ... i.e. exponent (E+00)
460
3/4
✓ Branch 2 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 2 times.
12 if(scan(to_lower(trim(adjustl(buffer))),&
461 'abcdfghijklmnopqrstuvwxyz').eq.0)then
462 6 cycle tag_loop
463
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 elseif(tag(:3).eq.'END')then
464 6 cycle tag_loop
465 end if
466 write(err_msg,'("Unrecognised line in input file: ",A)') &
467 trim(adjustl(buffer))
468 call stop_program(err_msg)
469
12/13
✓ Branch 0 taken 16 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✓ Branch 5 taken 1 times.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 1 times.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 1 times.
✓ Branch 10 taken 1 times.
✓ Branch 11 taken 1 times.
✓ Branch 12 taken 6 times.
32 return
470 end select
471 end do tag_loop
472
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 gamma_initialiser = initialiser_setup(gamma_initialiser_name)
473
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 beta_initialiser = initialiser_setup(beta_initialiser_name)
474
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 moving_mean_initialiser = initialiser_setup(moving_mean_initialiser_name)
475
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 moving_variance_initialiser = initialiser_setup(moving_variance_initialiser_name)
476
477
478 ! Set hyperparameters and initialise layer
479 !---------------------------------------------------------------------------
480 1 num_channels = input_shape(size(input_shape))
481 call this%set_hyperparams( &
482 momentum = momentum, &
483 epsilon = epsilon, &
484 gamma_init_mean = this%gamma_init_mean, &
485 gamma_init_std = this%gamma_init_std, &
486 beta_init_mean = this%beta_init_mean, &
487 beta_init_std = this%beta_init_std, &
488 gamma_initialiser = gamma_initialiser, &
489 beta_initialiser = beta_initialiser, &
490 moving_mean_initialiser = moving_mean_initialiser, &
491 moving_variance_initialiser = moving_variance_initialiser, &
492 verbose = verbose_ &
493 1 )
494 1 call this%init(input_shape = input_shape)
495
496
497 ! Check if WEIGHTS card was found
498 !---------------------------------------------------------------------------
499
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_channels), source=0._real32)
500
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
3 do i = 2, 1, -1
501
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
2 if(param_lines(i).eq.0) cycle
502
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 call move(unit, param_lines(i) - iline, iostat=stat)
503
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 iline = param_lines(i) + 1
504 2 c = 1
505 2 k = 1
506 2 data_list = 0._real32
507
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 4 times.
6 data_concat_loop: do while(c.le.num_channels)
508 4 iline = iline + 1
509 4 read(unit,'(A)',iostat=stat) buffer
510
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if(stat.ne.0) exit data_concat_loop
511 4 k = icount(buffer)
512
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 8 times.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
8 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
513 4 c = c + k
514 end do data_concat_loop
515 2 read(unit,'(A)',iostat=stat) buffer
516 1 select case(i)
517 case(1) ! gamma
518
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 2 times.
✓ Branch 40 taken 1 times.
3 this%params(1)%val(1:this%num_channels,1) = data_list
519
2/4
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
1 if(trim(adjustl(buffer)).ne."END GAMMA")then
520 write(err_msg,'("END GAMMA not where expected: ",A)') &
521 trim(adjustl(buffer))
522 call stop_program(err_msg)
523 return
524 end if
525 case(2) ! beta
526 9 this%params(1)%val(this%num_channels+1:this%num_channels*2,1) = &
527
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 2 times.
✓ Branch 40 taken 1 times.
3 data_list
528
4/7
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
3 if(trim(adjustl(buffer)).ne."END BETA")then
529 write(err_msg,'("END BETA not where expected: ",A)') &
530 trim(adjustl(buffer))
531 call stop_program(err_msg)
532 return
533 end if
534 end select
535 end do
536
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
537
538
539 ! Check for end of layer card
540 !---------------------------------------------------------------------------
541 1 call move(unit, final_line - iline, iostat=stat)
542 1 read(unit,'(A)') buffer
543
3/6
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
2 if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
544 write(0,*) trim(adjustl(buffer))
545 write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
546 call stop_program(err_msg)
547 1 return
548 end if
549
550
9/18
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
1 end subroutine read_batchnorm1d
551 !###############################################################################
552
553
554 !###############################################################################
555 1 function read_batchnorm1d_layer(unit, verbose) result(layer)
556 !! Read 1D batch normalisation layer from file and return layer
557 implicit none
558
559 ! Arguments
560 integer, intent(in) :: unit
561 !! File unit
562 integer, optional, intent(in) :: verbose
563 !! Verbosity level
564 class(base_layer_type), allocatable :: layer
565 !! Allocatable instance of the base layer
566
567 ! Local variables
568 integer :: verbose_ = 0
569 !! Verbosity level
570
571
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(present(verbose)) verbose_ = verbose
572
28/94
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✗ Branch 47 not taken.
✓ Branch 48 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 1 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✓ Branch 65 taken 1 times.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✓ Branch 68 taken 1 times.
✓ Branch 70 taken 1 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 1 times.
✗ Branch 73 not taken.
✗ Branch 74 not taken.
✓ Branch 75 taken 1 times.
✓ Branch 77 taken 1 times.
✗ Branch 78 not taken.
✗ Branch 79 not taken.
✓ Branch 80 taken 1 times.
✗ Branch 81 not taken.
✗ Branch 82 not taken.
✗ Branch 84 not taken.
✓ Branch 85 taken 1 times.
✓ Branch 86 taken 1 times.
✗ Branch 87 not taken.
✗ Branch 88 not taken.
✓ Branch 89 taken 1 times.
✓ Branch 91 taken 1 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 1 times.
✗ Branch 94 not taken.
✗ Branch 95 not taken.
✓ Branch 96 taken 1 times.
✓ Branch 98 taken 1 times.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✓ Branch 101 taken 1 times.
✗ Branch 102 not taken.
✓ Branch 103 taken 1 times.
2 allocate(layer, source=batchnorm1d_layer_type())
573 1 call layer%read(unit, verbose=verbose_)
574
575 2 end function read_batchnorm1d_layer
576 !###############################################################################
577
578
579 !###############################################################################
580 subroutine build_from_onnx_batchnorm1d( &
581 this, node, initialisers, value_info, verbose &
582 )
583 !! Read ONNX attributes for 1D batch normalisation layer
584 use athena__initialiser_data, only: data_init_type
585 implicit none
586
587 ! Arguments
588 class(batchnorm1d_layer_type), intent(inout) :: this
589 !! Instance of the 1D batch normalisation layer
590 type(onnx_node_type), intent(in) :: node
591 !! ONNX node information
592 type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
593 !! ONNX initialiser information
594 type(onnx_tensor_type), dimension(:), intent(in) :: value_info
595 !! ONNX value info
596 integer, intent(in) :: verbose
597 !! Verbosity level
598
599 ! Local variables
600 integer :: i
601 !! Loop index
602 real(real32) :: epsilon, momentum
603 !! Epsilon and momentum values
604 character(256) :: val
605 !! Attribute value
606 class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
607 class(base_init_type), allocatable :: &
608 moving_mean_initialiser, moving_variance_initialiser
609
610 ! Set default values
611 epsilon = 1.E-5_real32
612 momentum = 0.9_real32
613
614 ! Parse ONNX attributes
615 do i = 1, size(node%attributes)
616 val = node%attributes(i)%val
617 select case(trim(adjustl(node%attributes(i)%name)))
618 case("epsilon")
619 read(val,*) epsilon
620 case("momentum")
621 read(val,*) momentum
622 case default
623 ! Do nothing
624 write(0,*) "WARNING: Unrecognised attribute in ONNX BATCHNORM1D &
625 &layer: ", trim(adjustl(node%attributes(i)%name))
626 end select
627 end do
628
629 ! Check for 4 initialisers: gamma, beta, mean, variance
630 if(size(initialisers).ne.4)then
631 call stop_program("ONNX BATCHNORM1D layer requires 4 initialisers &
632 &(gamma, beta, mean, variance)")
633 return
634 end if
635
636 ! ONNX BatchNormalisation order: gamma, beta, mean, variance
637 gamma_initialiser = data_init_type( data = initialisers(1)%data )
638 beta_initialiser = data_init_type( data = initialisers(2)%data )
639 moving_mean_initialiser = data_init_type( data = initialisers(3)%data )
640 moving_variance_initialiser = data_init_type( data = initialisers(4)%data )
641
642 call this%set_hyperparams( &
643 momentum = momentum, &
644 epsilon = epsilon, &
645 gamma_init_mean = 1.0_real32, &
646 gamma_init_std = 0.0_real32, &
647 beta_init_mean = 0.0_real32, &
648 beta_init_std = 0.0_real32, &
649 gamma_initialiser = gamma_initialiser, &
650 beta_initialiser = beta_initialiser, &
651 moving_mean_initialiser = moving_mean_initialiser, &
652 moving_variance_initialiser = moving_variance_initialiser, &
653 verbose = verbose &
654 )
655
656 end subroutine build_from_onnx_batchnorm1d
657 !###############################################################################
658
659
660 !##############################################################################!
661 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
662 !##############################################################################!
663
664
665 !###############################################################################
666
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 subroutine forward_batchnorm1d(this, input)
667 !! Forward propagation
668 implicit none
669
670 ! Arguments
671 class(batchnorm1d_layer_type), intent(inout) :: this
672 !! Instance of the 1D batch normalisation layer
673 class(array_type), dimension(:,:), intent(in) :: input
674 !! Input values
675
676 ! Local variables
677 class(batchnorm_array_type), pointer :: ptr
678 ! Pointer array
679
680
681 3 select case(this%inference)
682 case(.true.)
683 ! Do not perform the drop operation
684
685 ptr => batchnorm_inference(input(1,1), this%params(1), &
686 this%mean(:), this%variance(:), this%epsilon &
687 )
688
689 case default
690 ! Perform the drop operation
691 ptr => batchnorm( &
692 input(1,1), this%params(1),&
693 24 this%momentum, this%mean(:), this%variance(:), this%epsilon &
694
13/26
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3 times.
3 )
695
696 end select
697 12 select type(output => this%output(1,1))
698 type is(batchnorm_array_type)
699 3 call output%assign_shallow(ptr)
700 3 output%epsilon = ptr%epsilon
701
15/24
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✓ Branch 14 taken 18 times.
✓ Branch 15 taken 3 times.
✓ Branch 16 taken 1 times.
✓ Branch 17 taken 2 times.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 18 times.
✓ Branch 27 taken 3 times.
39 output%mean = ptr%mean
702
15/24
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✓ Branch 15 taken 18 times.
✓ Branch 16 taken 3 times.
✓ Branch 17 taken 1 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 1 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 18 times.
✓ Branch 28 taken 3 times.
42 output%variance = ptr%variance
703 end select
704
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
3 deallocate(ptr)
705
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
3 this%output(1,1)%is_temporary = .false.
706
707 3 end subroutine forward_batchnorm1d
708 !###############################################################################
709
710
76/178
✓ Branch 0 taken 10 times.
✓ Branch 1 taken 9 times.
✓ Branch 2 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✓ Branch 5 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 3 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 9 times.
✓ Branch 37 taken 1 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 9 times.
✓ Branch 40 taken 6 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 9 times.
✓ Branch 43 taken 6 times.
✓ Branch 44 taken 9 times.
✓ Branch 45 taken 15 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 9 times.
✓ Branch 48 taken 7 times.
✓ Branch 49 taken 15 times.
✓ Branch 50 taken 1 times.
✓ Branch 51 taken 15 times.
✓ Branch 52 taken 1 times.
✓ Branch 53 taken 9 times.
✓ Branch 54 taken 1 times.
✓ Branch 55 taken 9 times.
✓ Branch 56 taken 1 times.
✓ Branch 57 taken 9 times.
✓ Branch 58 taken 9 times.
✓ Branch 59 taken 9 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 9 times.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✓ Branch 64 taken 9 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 9 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 9 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 15 times.
✓ Branch 71 taken 10 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 6 times.
✓ Branch 74 taken 16 times.
✗ Branch 75 not taken.
✗ Branch 76 not taken.
✓ Branch 77 taken 6 times.
✓ Branch 78 taken 10 times.
✗ Branch 79 not taken.
✓ Branch 80 taken 10 times.
✗ Branch 81 not taken.
✗ Branch 82 not taken.
✓ Branch 83 taken 10 times.
✗ Branch 84 not taken.
✗ Branch 85 not taken.
✓ Branch 86 taken 9 times.
✓ Branch 87 taken 1 times.
✓ Branch 88 taken 9 times.
✓ Branch 89 taken 1 times.
✗ Branch 90 not taken.
✗ Branch 91 not taken.
✗ Branch 92 not taken.
✗ Branch 93 not taken.
✗ Branch 94 not taken.
✗ Branch 95 not taken.
✗ Branch 96 not taken.
✗ Branch 97 not taken.
✗ Branch 98 not taken.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✗ Branch 101 not taken.
✗ Branch 102 not taken.
✗ Branch 103 not taken.
✗ Branch 104 not taken.
✗ Branch 105 not taken.
✗ Branch 106 not taken.
✗ Branch 107 not taken.
✗ Branch 108 not taken.
✗ Branch 109 not taken.
✗ Branch 110 not taken.
✗ Branch 111 not taken.
✓ Branch 112 taken 6 times.
✗ Branch 113 not taken.
✓ Branch 114 taken 6 times.
✗ Branch 115 not taken.
✓ Branch 116 taken 6 times.
✗ Branch 117 not taken.
✗ Branch 118 not taken.
✗ Branch 119 not taken.
✓ Branch 121 taken 6 times.
✗ Branch 122 not taken.
✓ Branch 123 taken 6 times.
✗ Branch 124 not taken.
✓ Branch 125 taken 6 times.
✗ Branch 126 not taken.
✗ Branch 127 not taken.
✓ Branch 128 taken 6 times.
✓ Branch 129 taken 6 times.
✗ Branch 130 not taken.
✗ Branch 131 not taken.
✓ Branch 132 taken 6 times.
✓ Branch 133 taken 6 times.
✗ Branch 134 not taken.
✗ Branch 135 not taken.
✓ Branch 136 taken 6 times.
✓ Branch 137 taken 6 times.
✗ Branch 138 not taken.
✗ Branch 139 not taken.
✓ Branch 140 taken 6 times.
✓ Branch 142 taken 6 times.
✗ Branch 143 not taken.
✗ Branch 144 not taken.
✓ Branch 145 taken 6 times.
✗ Branch 146 not taken.
✓ Branch 147 taken 6 times.
✓ Branch 149 taken 6 times.
✗ Branch 150 not taken.
✗ Branch 151 not taken.
✓ Branch 152 taken 6 times.
✗ Branch 153 not taken.
✓ Branch 154 taken 6 times.
✓ Branch 156 taken 6 times.
✗ Branch 157 not taken.
✓ Branch 158 taken 6 times.
✗ Branch 159 not taken.
✗ Branch 160 not taken.
✗ Branch 161 not taken.
✓ Branch 163 taken 6 times.
✗ Branch 164 not taken.
✗ Branch 165 not taken.
✓ Branch 166 taken 6 times.
✗ Branch 167 not taken.
✓ Branch 168 taken 6 times.
✓ Branch 170 taken 6 times.
✗ Branch 171 not taken.
✗ Branch 172 not taken.
✓ Branch 173 taken 6 times.
✗ Branch 174 not taken.
✓ Branch 175 taken 6 times.
✓ Branch 177 taken 6 times.
✗ Branch 178 not taken.
✓ Branch 179 taken 6 times.
✗ Branch 180 not taken.
✓ Branch 181 taken 6 times.
✗ Branch 182 not taken.
✓ Branch 183 taken 6 times.
✗ Branch 184 not taken.
119 end module athena__batchnorm1d_layer
711