GCC Code Coverage Report


Directory: src/athena/
File: athena_batchnorm3d_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
module athena__batchnorm3d_layer
  !! Batch normalisation layer for rank-4 (3D convolutional) data.
  !!
  !! Each channel has a learnable scale \(\gamma\) and shift \(\beta\)
  !! (stored together in the layer's first parameter array: gamma first,
  !! then beta). The layer also holds running estimates of the mean and
  !! variance, which are used at inference time.
  !!
  !! Training-mode operation:
  !! \[ \mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i \]
  !! \[ \sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 \]
  !! \[ \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \]
  !! \[ y_i = \gamma \hat{x}_i + \beta \]
  !!
  !! where \(\epsilon\) is a small constant for numerical stability. The
  !! reduction over \(m\) is carried out by `batchnorm` from
  !! athena__diffstruc_extd (presumably over the batch and the three spatial
  !! dimensions of each channel; confirm there).
  !!
  !! Inference-mode operation replaces \(\mu_\mathcal{B}\) and
  !! \(\sigma^2_\mathcal{B}\) with the running statistics
  !! \(\mu_{\text{running}}, \sigma^2_{\text{running}}\) gathered in training.
  !!
  !! Benefits: reduces internal covariate shift, enables higher learning
  !! rates, acts as regularisation, reduces dependence on initialisation.
  !! Reference: Ioffe & Szegedy (2015), ICML
  use coreutils, only: real32, print_warning, stop_program
  use diffstruc, only: array_type
  use athena__base_layer, only: base_layer_type, batch_layer_type
  use athena__misc_types, only: base_init_type, &
       onnx_initialiser_type, onnx_node_type, onnx_tensor_type
  use athena__diffstruc_extd, only: batchnorm, batchnorm_inference, &
       batchnorm_array_type
  implicit none


  private

  ! Public API: the layer type (with its constructor) and the file reader
  public :: batchnorm3d_layer_type, read_batchnorm3d_layer
35
36
type, extends(batch_layer_type) :: batchnorm3d_layer_type
  !! 3D batch normalisation layer.
  !!
  !! Overrides hyperparameter setup, file reading and forward propagation
  !! from batch_layer_type; everything else is inherited.
  !! NOTE(review): build_from_onnx_batchnorm3d in this module is not bound
  !! here; confirm whether it should be.
contains
  ! --- configuration ---
  procedure, pass(this) :: set_hyperparams => set_hyperparams_batchnorm3d
  !! Store momentum, epsilon and the gamma/beta/moving-statistic initialisers
  procedure, pass(this) :: read => read_batchnorm3d
  !! Populate the layer from a layer card in an open file

  ! --- computation ---
  procedure, pass(this) :: forward => forward_batchnorm3d
  !! Normalise the input with batch or running statistics

  ! --- clean-up ---
  final :: finalise_batchnorm3d
  !! Deallocate running statistics, input shape and output
end type batchnorm3d_layer_type
51
interface batchnorm3d_layer_type
  !! Constructor: `batchnorm3d_layer_type(...)` resolves to layer_setup
  module function layer_setup( &
       input_shape, &
       momentum, epsilon, &
       gamma_init_mean, gamma_init_std, &
       beta_init_mean, beta_init_std, &
       gamma_initialiser, beta_initialiser, &
       moving_mean_initialiser, moving_variance_initialiser, &
       verbose &
  ) result(layer)
    !! Build and optionally initialise a 3D batch normalisation layer
    integer, dimension(:), optional, intent(in) :: input_shape
    !! Input shape; when present the layer is initialised immediately
    real(real32), optional, intent(in) :: momentum
    !! Running-statistics momentum
    real(real32), optional, intent(in) :: epsilon
    !! Numerical stability constant
    real(real32), optional, intent(in) :: gamma_init_mean
    !! Mean applied to the gamma initialiser
    real(real32), optional, intent(in) :: gamma_init_std
    !! Standard deviation applied to the gamma initialiser
    real(real32), optional, intent(in) :: beta_init_mean
    !! Mean applied to the beta initialiser
    real(real32), optional, intent(in) :: beta_init_std
    !! Standard deviation applied to the beta initialiser
    class(*), optional, intent(in) :: gamma_initialiser
    !! Gamma initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: beta_initialiser
    !! Beta initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: moving_mean_initialiser
    !! Running-mean initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: moving_variance_initialiser
    !! Running-variance initialiser (anything accepted by initialiser_setup)
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    type(batchnorm3d_layer_type) :: layer
    !! Constructed layer
  end function layer_setup
end interface batchnorm3d_layer_type
82
83
84
85 contains
86
87 !###############################################################################
subroutine finalise_batchnorm3d(this)
  !! Finaliser: release any allocated running statistics, input shape and
  !! output storage held by the layer.
  implicit none

  type(batchnorm3d_layer_type), intent(inout) :: this
  !! Layer being finalised

  ! Running statistics
  if(allocated(this%mean))        deallocate(this%mean)
  if(allocated(this%variance))    deallocate(this%variance)

  ! Shape metadata and output buffer
  if(allocated(this%input_shape)) deallocate(this%input_shape)
  if(allocated(this%output))      deallocate(this%output)

end subroutine finalise_batchnorm3d
102 !###############################################################################
103
104
105 !##############################################################################!
106 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
107 !##############################################################################!
108
109
110 !###############################################################################
module function layer_setup( &
     input_shape, &
     momentum, epsilon, &
     gamma_init_mean, gamma_init_std, &
     beta_init_mean, beta_init_std, &
     gamma_initialiser, beta_initialiser, &
     moving_mean_initialiser, moving_variance_initialiser, &
     verbose &
) result(layer)
  !! Set up the 3D batch normalisation layer.
  !!
  !! Missing momentum defaults to 0 and missing epsilon to 1e-5. The
  !! gamma/beta init mean/std are copied onto the layer only when present;
  !! otherwise the layer's existing (component-default) values are used.
  !! Initialisers that are present are converted with initialiser_setup.
  !! Absent ones are left unallocated so that set_hyperparams picks the
  !! defaults. If input_shape is present, the layer is also initialised.
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  integer, dimension(:), optional, intent(in) :: input_shape
  !! Input shape
  real(real32), optional, intent(in) :: momentum, epsilon
  !! Momentum and epsilon
  real(real32), optional, intent(in) :: gamma_init_mean, gamma_init_std
  !! Gamma initialisation mean and standard deviation
  real(real32), optional, intent(in) :: beta_init_mean, beta_init_std
  !! Beta initialisation mean and standard deviation
  class(*), optional, intent(in) :: &
       gamma_initialiser, beta_initialiser, &
       moving_mean_initialiser, moving_variance_initialiser
  !! Initialisers
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  type(batchnorm3d_layer_type) :: layer
  !! Instance of the 3D batch normalisation layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level (assigned at run time: an initialiser in the
  !! declaration would imply SAVE and leak the value between calls)
  class(base_init_type), allocatable :: &
       gamma_initialiser_, beta_initialiser_, &
       moving_mean_initialiser_, moving_variance_initialiser_
  !! Initialisers


  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set up momentum and epsilon
  !---------------------------------------------------------------------------
  if(present(momentum))then
     layer%momentum = momentum
  else
     layer%momentum = 0._real32
  end if
  if(present(epsilon))then
     layer%epsilon = epsilon
  else
     layer%epsilon = 1.E-5_real32
  end if


  !---------------------------------------------------------------------------
  ! Set up initialiser mean and standard deviations
  !---------------------------------------------------------------------------
  if(present(gamma_init_mean)) layer%gamma_init_mean = gamma_init_mean
  if(present(gamma_init_std)) layer%gamma_init_std = gamma_init_std
  if(present(beta_init_mean)) layer%beta_init_mean = beta_init_mean
  if(present(beta_init_std)) layer%beta_init_std = beta_init_std


  !---------------------------------------------------------------------------
  ! Define gamma, beta and moving-statistic initialisers
  !---------------------------------------------------------------------------
  if(present(gamma_initialiser))then
     gamma_initialiser_ = initialiser_setup(gamma_initialiser)
  end if
  if(present(beta_initialiser))then
     beta_initialiser_ = initialiser_setup(beta_initialiser)
  end if
  if(present(moving_mean_initialiser))then
     moving_mean_initialiser_ = initialiser_setup(moving_mean_initialiser)
  end if
  if(present(moving_variance_initialiser))then
     moving_variance_initialiser_ = &
          initialiser_setup(moving_variance_initialiser)
  end if


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  call layer%set_hyperparams( &
       momentum = layer%momentum, epsilon = layer%epsilon, &
       gamma_init_mean = layer%gamma_init_mean, &
       gamma_init_std = layer%gamma_init_std, &
       beta_init_mean = layer%beta_init_mean, &
       beta_init_std = layer%beta_init_std, &
       gamma_initialiser = gamma_initialiser_, &
       beta_initialiser = beta_initialiser_, &
       moving_mean_initialiser = moving_mean_initialiser_, &
       moving_variance_initialiser = moving_variance_initialiser_, &
       verbose = verbose_ &
  )


  !---------------------------------------------------------------------------
  ! Initialise layer shape
  !---------------------------------------------------------------------------
  if(present(input_shape)) call layer%init(input_shape=input_shape)

end function layer_setup
219 !###############################################################################
220
221
222 !###############################################################################
subroutine set_hyperparams_batchnorm3d( &
     this, &
     momentum, epsilon, &
     gamma_init_mean, gamma_init_std, &
     beta_init_mean, beta_init_std, &
     gamma_initialiser, beta_initialiser, &
     moving_mean_initialiser, moving_variance_initialiser, &
     verbose )
  !! Store the hyperparameters and initialisers of a 3D batch normalisation
  !! layer.
  !!
  !! Unallocated initialisers fall back to defaults: gamma -> 'ones',
  !! beta -> 'zeros', moving mean -> 'zeros', moving variance -> 'ones'.
  !! The gamma/beta mean and standard deviation are then written onto the
  !! kernel (gamma) and bias (beta) initialisers. A non-zero verbosity prints
  !! the chosen initialiser names.
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  class(batchnorm3d_layer_type), intent(inout) :: this
  !! Instance of the 3D batch normalisation layer
  real(real32), intent(in) :: momentum
  !! Running-statistics momentum
  real(real32), intent(in) :: epsilon
  !! Numerical stability constant
  real(real32), intent(in) :: gamma_init_mean
  !! Mean for the gamma initialiser
  real(real32), intent(in) :: gamma_init_std
  !! Standard deviation for the gamma initialiser
  real(real32), intent(in) :: beta_init_mean
  !! Mean for the beta initialiser
  real(real32), intent(in) :: beta_init_std
  !! Standard deviation for the beta initialiser
  class(base_init_type), allocatable, intent(in) :: gamma_initialiser
  !! Gamma initialiser (may be unallocated)
  class(base_init_type), allocatable, intent(in) :: beta_initialiser
  !! Beta initialiser (may be unallocated)
  class(base_init_type), allocatable, intent(in) :: moving_mean_initialiser
  !! Running-mean initialiser (may be unallocated)
  class(base_init_type), allocatable, intent(in) :: &
       moving_variance_initialiser
  !! Running-variance initialiser (may be unallocated)
  integer, optional, intent(in) :: verbose
  !! Verbosity level


  ! Layer identity and basic hyperparameters
  this%name        = "batchnorm3d"
  this%type        = "batc"
  this%input_rank  = 4
  this%output_rank = 4
  this%use_bias    = .true.
  this%momentum    = momentum
  this%epsilon     = epsilon

  ! Gamma is stored as the layer kernel
  if(allocated(this%kernel_init)) deallocate(this%kernel_init)
  if(allocated(gamma_initialiser))then
     allocate(this%kernel_init, source=gamma_initialiser)
  else
     this%kernel_init = initialiser_setup('ones')
  end if

  ! Beta is stored as the layer bias
  if(allocated(this%bias_init)) deallocate(this%bias_init)
  if(allocated(beta_initialiser))then
     allocate(this%bias_init, source=beta_initialiser)
  else
     this%bias_init = initialiser_setup('zeros')
  end if

  ! Running statistics
  if(allocated(moving_mean_initialiser))then
     this%moving_mean_init = moving_mean_initialiser
  else
     this%moving_mean_init = initialiser_setup('zeros')
  end if
  if(allocated(moving_variance_initialiser))then
     this%moving_variance_init = moving_variance_initialiser
  else
     this%moving_variance_init = initialiser_setup('ones')
  end if

  ! Record the distribution parameters and push them onto the initialisers
  this%gamma_init_mean = gamma_init_mean
  this%gamma_init_std  = gamma_init_std
  this%beta_init_mean  = beta_init_mean
  this%beta_init_std   = beta_init_std
  this%kernel_init%mean = this%gamma_init_mean
  this%kernel_init%std  = this%gamma_init_std
  this%bias_init%mean   = this%beta_init_mean
  this%bias_init%std    = this%beta_init_std

  ! Report choices (nested test: Fortran does not short-circuit .and.)
  if(.not.present(verbose)) return
  if(abs(verbose).le.0) return
  write(*,'("BATCHNORM3D gamma (kernel) initialiser: ",A)') &
       trim(this%kernel_init%name)
  write(*,'("BATCHNORM3D beta (bias) initialiser: ",A)') &
       trim(this%bias_init%name)
  write(*,'("BATCHNORM3D moving mean initialiser: ",A)') &
       trim(this%moving_mean_init%name)
  write(*,'("BATCHNORM3D moving variance initialiser: ",A)') &
       trim(this%moving_variance_init%name)

end subroutine set_hyperparams_batchnorm3d
304 !###############################################################################
305
306
307 !##############################################################################!
308 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
309 !##############################################################################!
310
311
312 !###############################################################################
subroutine read_batchnorm3d(this, unit, verbose)
  !! Read a 3D batch normalisation layer card from an open file.
  !!
  !! Reads tag lines up to "END " followed by the upper-cased layer name.
  !! It then sets the hyperparameters and initialises the layer from
  !! INPUT_SHAPE. If GAMMA and/or BETA data blocks are present, their values
  !! are copied into this%params(1): gamma into elements 1:num_channels and
  !! beta into num_channels+1:2*num_channels.
  !! On success the file is left positioned just after the END line.
  !!
  !! NOTE(review): the line bookkeeping (iline / param_lines / final_line)
  !! skips blank lines, while `move` presumably counts records. Blank lines
  !! inside the card may therefore misplace the GAMMA/BETA re-read; confirm
  !! against athena__tools_infile::move.
  use athena__tools_infile, only: assign_val, assign_vec, move
  use coreutils, only: to_lower, to_upper, icount
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  class(batchnorm3d_layer_type), intent(inout) :: this
  !! Instance of the 3D batch normalisation layer
  integer, intent(in) :: unit
  !! File unit
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  ! None of these are initialised in their declarations: that would give
  ! them the SAVE attribute and leak values from one layer card into the next.
  integer :: stat, verbose_
  !! Status and verbosity level
  integer :: i, j, k, c, itmp1, iline, final_line
  !! Loop variables and temporary integer
  integer :: num_channels
  !! Number of channels (last element of INPUT_SHAPE)
  logical :: input_shape_found
  !! Whether an INPUT_SHAPE tag was read
  real(real32) :: momentum, epsilon
  !! Momentum and epsilon
  class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
  !! Initialisers
  class(base_init_type), allocatable :: &
       moving_mean_initialiser, moving_variance_initialiser
  !! Moving mean and variance initialisers
  character(14) :: gamma_initialiser_name, beta_initialiser_name
  !! Initialiser names
  character(14) :: &
       moving_mean_initialiser_name, &
       moving_variance_initialiser_name
  !! Moving mean and variance initialiser names
  character(256) :: buffer, tag, err_msg
  !! Buffer, tag, and error message
  integer, dimension(4) :: input_shape
  !! Input shape
  real(real32), allocatable, dimension(:) :: data_list
  !! Data list
  integer, dimension(2) :: param_lines
  !! Lines where the GAMMA (1) and BETA (2) data blocks start


  ! Initialise defaults (run-time assignment, see note above)
  !---------------------------------------------------------------------------
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  momentum = 0._real32
  epsilon = 1.E-5_real32
  gamma_initialiser_name = ''
  beta_initialiser_name = ''
  moving_mean_initialiser_name = ''
  moving_variance_initialiser_name = ''
  input_shape = 0
  input_shape_found = .false.


  ! Loop over tags in layer card
  !---------------------------------------------------------------------------
  iline = 0
  param_lines = 0
  final_line = 0
  tag_loop: do

     ! Check for end of file
     !------------------------------------------------------------------------
     read(unit,'(A)',iostat=stat) buffer
     if(stat.ne.0)then
        write(err_msg,'("file encountered error (EoF?) before END ",A)') &
             to_upper(this%name)
        call stop_program(err_msg)
        return
     end if
     if(trim(adjustl(buffer)).eq."") cycle tag_loop

     ! Check for end of layer card
     !------------------------------------------------------------------------
     if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
        final_line = iline
        backspace(unit)
        exit tag_loop
     end if
     iline = iline + 1

     tag=trim(adjustl(buffer))
     if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

     ! Read parameters from save file
     !------------------------------------------------------------------------
     select case(trim(tag))
     case("INPUT_SHAPE")
        call assign_vec(buffer, input_shape, itmp1)
        input_shape_found = .true.
     case("MOMENTUM")
        call assign_val(buffer, momentum, itmp1)
     case("EPSILON")
        call assign_val(buffer, epsilon, itmp1)
     case("NUM_CHANNELS")
        call assign_val(buffer, num_channels, itmp1)
        write(0,*) "NUM_CHANNELS and INPUT_SHAPE are conflicting parameters"
        write(0,*) "NUM_CHANNELS will be ignored"
     case("GAMMA_INITIALISER", "KERNEL_INITIALISER")
        if(param_lines(1).ne.0)then
           write(err_msg,'("GAMMA and GAMMA_INITIALISER defined. Using GAMMA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, gamma_initialiser_name, itmp1)
     case("BETA_INITIALISER", "BIAS_INITIALISER")
        if(param_lines(2).ne.0)then
           write(err_msg,'("BETA and BETA_INITIALISER defined. Using BETA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, beta_initialiser_name, itmp1)
     case("MOVING_MEAN_INITIALISER")
        call assign_val(buffer, moving_mean_initialiser_name, itmp1)
     case("MOVING_VARIANCE_INITIALISER")
        call assign_val(buffer, moving_variance_initialiser_name, itmp1)
     case("GAMMA")
        ! Placeholder initialiser; real values are loaded from the data block
        gamma_initialiser_name = 'zeros'
        param_lines(1) = iline
     case("BETA")
        beta_initialiser_name = 'zeros'
        param_lines(2) = iline
     case default
        ! Don't look for "e" due to scientific notation of numbers
        ! ... i.e. exponent (E+00)
        if(scan(to_lower(trim(adjustl(buffer))),&
             'abcdfghijklmnopqrstuvwxyz').eq.0)then
           cycle tag_loop
        elseif(tag(:3).eq.'END')then
           cycle tag_loop
        end if
        write(err_msg,'("Unrecognised line in input file: ",A)') &
             trim(adjustl(buffer))
        call stop_program(err_msg)
        return
     end select
  end do tag_loop

  ! INPUT_SHAPE is mandatory: it fixes the channel count and layer shape
  if(.not.input_shape_found)then
     write(err_msg,'("INPUT_SHAPE not defined in ",A," layer card")') &
          to_upper(trim(this%name))
     call stop_program(err_msg)
     return
  end if

  gamma_initialiser = initialiser_setup(gamma_initialiser_name)
  beta_initialiser = initialiser_setup(beta_initialiser_name)
  moving_mean_initialiser = initialiser_setup(moving_mean_initialiser_name)
  moving_variance_initialiser = &
       initialiser_setup(moving_variance_initialiser_name)


  ! Set hyperparameters and initialise layer
  !---------------------------------------------------------------------------
  num_channels = input_shape(size(input_shape,1))
  call this%set_hyperparams( &
       momentum = momentum, &
       epsilon = epsilon, &
       gamma_init_mean = this%gamma_init_mean, &
       gamma_init_std = this%gamma_init_std, &
       beta_init_mean = this%beta_init_mean, &
       beta_init_std = this%beta_init_std, &
       gamma_initialiser = gamma_initialiser, &
       beta_initialiser = beta_initialiser, &
       moving_mean_initialiser = moving_mean_initialiser, &
       moving_variance_initialiser = moving_variance_initialiser, &
       verbose = verbose_ &
  )
  call this%init(input_shape = input_shape)


  ! Load GAMMA / BETA data blocks if present (beta first, then gamma)
  !---------------------------------------------------------------------------
  allocate(data_list(num_channels), source=0._real32)
  param_loop: do i = 2, 1, -1
     if(param_lines(i).eq.0) cycle param_loop
     call move(unit, param_lines(i) - iline, iostat=stat)
     ! +1 accounts for the END GAMMA/BETA line read after the data below
     iline = param_lines(i) + 1
     c = 1
     data_list = 0._real32
     data_concat_loop: do while(c.le.num_channels)
        iline = iline + 1
        read(unit,'(A)',iostat=stat) buffer
        if(stat.ne.0) exit data_concat_loop
        k = icount(buffer)
        ! Guard against writing past the end of data_list
        if(c + k - 1 .gt. num_channels)then
           write(err_msg,'("Too many values in GAMMA/BETA block of ",A)') &
                to_upper(trim(this%name))
           call stop_program(err_msg)
           return
        end if
        read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
        c = c + k
     end do data_concat_loop
     read(unit,'(A)',iostat=stat) buffer
     select case(i)
     case(1) ! gamma
        this%params(1)%val(1:this%num_channels,1) = data_list
        if(trim(adjustl(buffer)).ne."END GAMMA")then
           write(err_msg,'("END GAMMA not where expected: ",A)') &
                trim(adjustl(buffer))
           call stop_program(err_msg)
           return
        end if
     case(2) ! beta
        this%params(1)%val(this%num_channels+1:this%num_channels*2,1) = &
             data_list
        if(trim(adjustl(buffer)).ne."END BETA")then
           write(err_msg,'("END BETA not where expected: ",A)') &
                trim(adjustl(buffer))
           call stop_program(err_msg)
           return
        end if
     end select
  end do param_loop
  deallocate(data_list)


  ! Check for end of layer card
  !---------------------------------------------------------------------------
  call move(unit, final_line - iline, iostat=stat)
  read(unit,'(A)',iostat=stat) buffer
  ! A failed read must not silently compare against a stale buffer
  if(stat.ne.0) buffer = ''
  if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
     write(0,*) trim(adjustl(buffer))
     write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

end subroutine read_batchnorm3d
522 !###############################################################################
523
524
525 !###############################################################################
function read_batchnorm3d_layer(unit, verbose) result(layer)
  !! Create a 3D batch normalisation layer and populate it from a layer
  !! card read from `unit`.
  !!
  !! The layer is default-constructed via batchnorm3d_layer_type() and then
  !! filled by its `read` binding (read_batchnorm3d).
  implicit none

  ! Arguments
  integer, intent(in) :: unit
  !! File unit
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  class(base_layer_type), allocatable :: layer
  !! Instance of the 3D batch normalisation layer
  integer :: verbose_
  !! Verbosity level (assigned at run time: an initialiser in the
  !! declaration would imply SAVE and leak the value between calls)

  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  allocate(layer, source=batchnorm3d_layer_type())
  call layer%read(unit, verbose=verbose_)

end function read_batchnorm3d_layer
547 !###############################################################################
548
549
550 !###############################################################################
subroutine build_from_onnx_batchnorm3d( &
     this, node, initialisers, value_info, verbose &
)
  !! Configure a 3D batch normalisation layer from an ONNX
  !! BatchNormalization node.
  !!
  !! Reads the `epsilon` (default 1e-5) and `momentum` (default 0.9)
  !! attributes. It then builds data initialisers from exactly four ONNX
  !! initialisers, taken in the order gamma, beta, mean, variance.
  !! Unparseable attribute values are reported through stop_program.
  !!
  !! NOTE(review): this procedure is not bound to batchnorm3d_layer_type in
  !! this module, and value_info is unused; confirm whether it is reached.
  use athena__initialiser_data, only: data_init_type
  implicit none

  ! Arguments
  class(batchnorm3d_layer_type), intent(inout) :: this
  !! Instance of the 3D batch normalisation layer
  type(onnx_node_type), intent(in) :: node
  !! ONNX node information
  type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
  !! ONNX initialiser information
  type(onnx_tensor_type), dimension(:), intent(in) :: value_info
  !! ONNX value info
  integer, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: i, stat
  !! Loop index and I/O status
  real(real32) :: epsilon, momentum
  !! Epsilon and momentum values
  character(256) :: val
  !! Attribute value
  class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
  class(base_init_type), allocatable :: &
       moving_mean_initialiser, moving_variance_initialiser

  ! Set default values
  epsilon = 1.E-5_real32
  momentum = 0.9_real32

  ! Parse ONNX attributes; a malformed value is reported rather than
  ! aborting with a runtime I/O error
  do i = 1, size(node%attributes)
     val = node%attributes(i)%val
     select case(trim(adjustl(node%attributes(i)%name)))
     case("epsilon")
        read(val,*,iostat=stat) epsilon
        if(stat.ne.0)then
           call stop_program( &
                "Invalid epsilon attribute in ONNX BATCHNORM3D layer: " // &
                trim(adjustl(val)) &
           )
           return
        end if
     case("momentum")
        read(val,*,iostat=stat) momentum
        if(stat.ne.0)then
           call stop_program( &
                "Invalid momentum attribute in ONNX BATCHNORM3D layer: " // &
                trim(adjustl(val)) &
           )
           return
        end if
     case default
        ! Unknown attributes are ignored, but reported
        write(0,*) "WARNING: Unrecognised attribute in ONNX BATCHNORM3D &
             &layer: ", trim(adjustl(node%attributes(i)%name))
     end select
  end do

  ! Check for 4 initialisers: gamma, beta, mean, variance
  if(size(initialisers).ne.4)then
     call stop_program("ONNX BATCHNORM3D layer requires 4 initialisers &
          &(gamma, beta, mean, variance)")
     return
  end if

  ! ONNX BatchNormalization order: gamma, beta, mean, variance
  gamma_initialiser = data_init_type( data = initialisers(1)%data )
  beta_initialiser = data_init_type( data = initialisers(2)%data )
  moving_mean_initialiser = data_init_type( data = initialisers(3)%data )
  moving_variance_initialiser = data_init_type( data = initialisers(4)%data )

  call this%set_hyperparams( &
       momentum = momentum, &
       epsilon = epsilon, &
       gamma_init_mean = 1.0_real32, &
       gamma_init_std = 0.0_real32, &
       beta_init_mean = 0.0_real32, &
       beta_init_std = 0.0_real32, &
       gamma_initialiser = gamma_initialiser, &
       beta_initialiser = beta_initialiser, &
       moving_mean_initialiser = moving_mean_initialiser, &
       moving_variance_initialiser = moving_variance_initialiser, &
       verbose = verbose &
  )

end subroutine build_from_onnx_batchnorm3d
628 !###############################################################################
629
630
631 !##############################################################################!
632 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
633 !##############################################################################!
634
635
636 !###############################################################################
subroutine forward_batchnorm3d(this, input)
  !! Forward propagation for the 3D batch normalisation layer.
  !!
  !! In inference mode, batchnorm_inference normalises with the stored
  !! this%mean / this%variance. Otherwise, batchnorm is called with the
  !! momentum and the same arrays; presumably it uses batch statistics and
  !! updates the running ones (confirm in athena__diffstruc_extd).
  !! The result is shallow-assigned into this%output(1,1), along with its
  !! epsilon, mean and variance. The temporary returned by the call is then
  !! freed.
  implicit none

  ! Arguments
  class(batchnorm3d_layer_type), intent(inout) :: this
  !! Instance of the 3D batch normalisation layer
  class(array_type), dimension(:,:), intent(in) :: input
  !! Input values; only element (1,1) is used

  ! Local variables
  class(batchnorm_array_type), pointer :: bn_result
  !! Result returned by the normalisation routine


  if(this%inference)then
     ! Running statistics only
     bn_result => batchnorm_inference( &
          input(1,1), this%params(1), &
          this%mean(:), this%variance(:), this%epsilon &
     )
  else
     ! Batch statistics, with momentum for the running estimates
     bn_result => batchnorm( &
          input(1,1), this%params(1), &
          this%momentum, this%mean(:), this%variance(:), this%epsilon &
     )
  end if

  ! Only a batchnorm_array_type output slot is filled; any other dynamic
  ! type is left untouched (NOTE(review): confirm init always allocates
  ! this type)
  select type(bn_out => this%output(1,1))
  type is(batchnorm_array_type)
     call bn_out%assign_shallow(bn_result)
     bn_out%epsilon  = bn_result%epsilon
     bn_out%mean     = bn_result%mean
     bn_out%variance = bn_result%variance
  end select

  deallocate(bn_result)
  this%output(1,1)%is_temporary = .false.

end subroutine forward_batchnorm3d
679 !###############################################################################
680
681 end module athena__batchnorm3d_layer
682