GCC Code Coverage Report


Directory: src/athena/
File: athena_batchnorm2d_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__batchnorm2d_layer
2 !! Module containing implementation of 2D batch normalisation layer
3 !!
4 !! This module implements batch normalisation for 2D convolutional layers,
5 !! normalising activations across the batch dimension.
6 !!
7 !! Mathematical operation (training):
8 !! \[ \mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i \]
9 !! \[ \sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 \]
10 !! \[ \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \]
11 !! \[ y_i = \gamma \hat{x}_i + \beta \]
12 !!
13 !! where \(\gamma, \beta\) are learnable parameters and \(\epsilon\) is a small constant added for numerical stability
14 !!
15 !! Inference: uses running statistics
16 !! \(\mu_{\text{running}}, \sigma^2_{\text{running}}\) from training
17 !!
18 !! Benefits: Reduces internal covariate shift, enables higher learning rates,
19 !! acts as regularisation, reduces dependence on initialisation
20 !! Reference: Ioffe & Szegedy (2015), ICML
21 use coreutils, only: real32, stop_program, print_warning
22 use athena__base_layer, only: batch_layer_type, base_layer_type
23 use athena__misc_types, only: base_init_type, &
24 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
25 use diffstruc, only: array_type
26 use athena__diffstruc_extd, only: batchnorm_array_type, &
27 batchnorm, batchnorm_inference
28 implicit none
29
30
31 private
32
33 public :: batchnorm2d_layer_type
34 public :: read_batchnorm2d_layer
35
36
type, extends(batch_layer_type) :: batchnorm2d_layer_type
  !! 2D batch normalisation layer built on top of batch_layer_type.
  !!
  !! Gamma and beta are held in the inherited parameter storage (gamma in
  !! the first num_channels entries of params(1), beta in the following
  !! num_channels entries) and are initialised through the inherited
  !! kernel_init and bias_init slots respectively.
  !!
  !! NOTE(review): build_from_onnx_batchnorm2d is defined in this module but
  !! is not bound below -- confirm whether a build_from_onnx binding is
  !! intended.
contains
  procedure, pass(this) :: forward => forward_batchnorm2d
  !! Forward pass (training or inference mode)
  procedure, pass(this) :: read => read_batchnorm2d
  !! Parse the layer from a text layer card
  procedure, pass(this) :: set_hyperparams => set_hyperparams_batchnorm2d
  !! Configure momentum, epsilon and the four initialisers
  final :: finalise_batchnorm2d
  !! Release allocatable components on finalisation
end type batchnorm2d_layer_type
51
interface batchnorm2d_layer_type
  !! Generic constructor: batchnorm2d_layer_type(...) is resolved to
  !! layer_setup, whose body lives in this module.
  module function layer_setup( &
       input_shape, &
       momentum, epsilon, &
       gamma_init_mean, gamma_init_std, &
       beta_init_mean, beta_init_std, &
       gamma_initialiser, beta_initialiser, &
       moving_mean_initialiser, moving_variance_initialiser, &
       verbose &
  ) result(layer)
    !! Construct a 2D batch normalisation layer; every argument is optional.
    integer, dimension(:), optional, intent(in) :: input_shape
    !! Shape of the layer input
    real(real32), optional, intent(in) :: momentum
    !! Momentum used for the running statistics
    real(real32), optional, intent(in) :: epsilon
    !! Stabilising constant added to the variance
    real(real32), optional, intent(in) :: gamma_init_mean
    !! Mean passed to the gamma initialiser
    real(real32), optional, intent(in) :: gamma_init_std
    !! Standard deviation passed to the gamma initialiser
    real(real32), optional, intent(in) :: beta_init_mean
    !! Mean passed to the beta initialiser
    real(real32), optional, intent(in) :: beta_init_std
    !! Standard deviation passed to the beta initialiser
    class(*), optional, intent(in) :: gamma_initialiser
    !! Gamma initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: beta_initialiser
    !! Beta initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: moving_mean_initialiser
    !! Running-mean initialiser (anything accepted by initialiser_setup)
    class(*), optional, intent(in) :: moving_variance_initialiser
    !! Running-variance initialiser (anything accepted by initialiser_setup)
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    type(batchnorm2d_layer_type) :: layer
    !! The constructed layer
  end function layer_setup
end interface batchnorm2d_layer_type
82
83
84
85 contains
86
87 !###############################################################################
subroutine finalise_batchnorm2d(this)
  !! Final procedure for batchnorm2d_layer_type.
  !!
  !! Releases the output array, the stored input shape and the running
  !! statistics (mean and variance) when they are allocated.
  implicit none

  ! Arguments
  type(batchnorm2d_layer_type), intent(inout) :: this
  !! Layer instance being finalised

  if(allocated(this%output)) deallocate(this%output)
  if(allocated(this%input_shape)) deallocate(this%input_shape)
  if(allocated(this%variance)) deallocate(this%variance)
  if(allocated(this%mean)) deallocate(this%mean)

end subroutine finalise_batchnorm2d
102 !###############################################################################
103
104
105 !##############################################################################!
106 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
107 !##############################################################################!
108
109
110 !###############################################################################
module function layer_setup( &
     input_shape, &
     momentum, epsilon, &
     gamma_init_mean, gamma_init_std, &
     beta_init_mean, beta_init_std, &
     gamma_initialiser, beta_initialiser, &
     moving_mean_initialiser, moving_variance_initialiser, &
     verbose &
) result(layer)
  !! Set up the 2D batch normalisation layer.
  !!
  !! Every argument is optional:
  !!  - momentum defaults to 0 and epsilon to 1e-5;
  !!  - the gamma/beta init mean and std are only overwritten when present;
  !!  - each initialiser is converted with initialiser_setup when present and
  !!    otherwise passed on unallocated, so the set_hyperparams binding picks
  !!    its own default;
  !!  - the layer shape is initialised only when input_shape is supplied.
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  integer, dimension(:), optional, intent(in) :: input_shape
  !! Input shape
  real(real32), optional, intent(in) :: momentum, epsilon
  !! Momentum and epsilon
  real(real32), optional, intent(in) :: gamma_init_mean, gamma_init_std
  !! Gamma initialisation mean and standard deviation
  real(real32), optional, intent(in) :: beta_init_mean, beta_init_std
  !! Beta initialisation mean and standard deviation
  class(*), optional, intent(in) :: &
       gamma_initialiser, beta_initialiser, &
       moving_mean_initialiser, moving_variance_initialiser
  !! Initialisers
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  type(batchnorm2d_layer_type) :: layer
  !! Instance of the 2D batch normalisation layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level (0 when verbose is absent)
  class(base_init_type), allocatable :: &
       gamma_initialiser_, beta_initialiser_, &
       moving_mean_initialiser_, moving_variance_initialiser_
  !! Initialisers


  ! Assigned at run time rather than on the declaration: a declaration
  ! initialiser implies SAVE, so a verbosity level passed to one call would
  ! otherwise persist into later calls that omit `verbose`.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set up momentum and epsilon
  !---------------------------------------------------------------------------
  if(present(momentum))then
     layer%momentum = momentum
  else
     layer%momentum = 0._real32
  end if
  if(present(epsilon))then
     layer%epsilon = epsilon
  else
     layer%epsilon = 1.E-5_real32
  end if


  !---------------------------------------------------------------------------
  ! Set up initialiser mean and standard deviations
  !---------------------------------------------------------------------------
  if(present(gamma_init_mean)) layer%gamma_init_mean = gamma_init_mean
  if(present(gamma_init_std)) layer%gamma_init_std = gamma_init_std
  if(present(beta_init_mean)) layer%beta_init_mean = beta_init_mean
  if(present(beta_init_std)) layer%beta_init_std = beta_init_std


  !---------------------------------------------------------------------------
  ! Define gamma, beta and running-statistics initialisers
  !---------------------------------------------------------------------------
  if(present(gamma_initialiser))then
     gamma_initialiser_ = initialiser_setup(gamma_initialiser)
  end if
  if(present(beta_initialiser))then
     beta_initialiser_ = initialiser_setup(beta_initialiser)
  end if
  if(present(moving_mean_initialiser))then
     moving_mean_initialiser_ = initialiser_setup(moving_mean_initialiser)
  end if
  if(present(moving_variance_initialiser))then
     moving_variance_initialiser_ = &
          initialiser_setup(moving_variance_initialiser)
  end if


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  call layer%set_hyperparams( &
       momentum = layer%momentum, epsilon = layer%epsilon, &
       gamma_init_mean = layer%gamma_init_mean, &
       gamma_init_std = layer%gamma_init_std, &
       beta_init_mean = layer%beta_init_mean, &
       beta_init_std = layer%beta_init_std, &
       gamma_initialiser = gamma_initialiser_, &
       beta_initialiser = beta_initialiser_, &
       moving_mean_initialiser = moving_mean_initialiser_, &
       moving_variance_initialiser = moving_variance_initialiser_, &
       verbose = verbose_ &
  )


  !---------------------------------------------------------------------------
  ! Initialise layer shape
  !---------------------------------------------------------------------------
  if(present(input_shape)) call layer%init(input_shape=input_shape)

end function layer_setup
219 !###############################################################################
220
221
222 !###############################################################################
subroutine set_hyperparams_batchnorm2d( &
     this, &
     momentum, epsilon, &
     gamma_init_mean, gamma_init_std, &
     beta_init_mean, beta_init_std, &
     gamma_initialiser, beta_initialiser, &
     moving_mean_initialiser, moving_variance_initialiser, &
     verbose )
  !! Configure the hyperparameters of the 2D batch normalisation layer.
  !!
  !! Fixes the layer identity (name "batchnorm2d", type tag "batc", rank-3
  !! input and output, bias enabled), stores momentum/epsilon and the
  !! gamma/beta init statistics, and installs the four initialisers. An
  !! initialiser passed in unallocated falls back to a default:
  !!   gamma -> 'ones', beta -> 'zeros',
  !!   moving mean -> 'zeros', moving variance -> 'ones'.
  !! The gamma/beta mean and std are also copied onto the kernel/bias
  !! initialisers. With a non-zero verbose level the chosen initialiser
  !! names are printed.
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  class(batchnorm2d_layer_type), intent(inout) :: this
  !! Layer being configured
  real(real32), intent(in) :: momentum
  !! Momentum for the running statistics
  real(real32), intent(in) :: epsilon
  !! Stabilising constant added to the variance
  real(real32), intent(in) :: gamma_init_mean
  !! Mean for the gamma (kernel) initialiser
  real(real32), intent(in) :: gamma_init_std
  !! Standard deviation for the gamma (kernel) initialiser
  real(real32), intent(in) :: beta_init_mean
  !! Mean for the beta (bias) initialiser
  real(real32), intent(in) :: beta_init_std
  !! Standard deviation for the beta (bias) initialiser
  class(base_init_type), allocatable, intent(in) :: gamma_initialiser
  !! Gamma initialiser, or unallocated for the default
  class(base_init_type), allocatable, intent(in) :: beta_initialiser
  !! Beta initialiser, or unallocated for the default
  class(base_init_type), allocatable, intent(in) :: moving_mean_initialiser
  !! Running-mean initialiser, or unallocated for the default
  class(base_init_type), allocatable, intent(in) :: &
       moving_variance_initialiser
  !! Running-variance initialiser, or unallocated for the default
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Layer identity
  this%name = "batchnorm2d"
  this%type = "batc"
  this%input_rank = 3
  this%output_rank = 3
  this%use_bias = .true.

  ! Scalar hyperparameters
  this%momentum = momentum
  this%epsilon = epsilon
  this%gamma_init_mean = gamma_init_mean
  this%gamma_init_std = gamma_init_std
  this%beta_init_mean = beta_init_mean
  this%beta_init_std = beta_init_std

  ! Gamma is initialised through the kernel slot
  if(allocated(this%kernel_init)) deallocate(this%kernel_init)
  if(allocated(gamma_initialiser))then
     allocate(this%kernel_init, source=gamma_initialiser)
  else
     this%kernel_init = initialiser_setup('ones')
  end if
  this%kernel_init%mean = this%gamma_init_mean
  this%kernel_init%std = this%gamma_init_std

  ! Beta is initialised through the bias slot
  if(allocated(this%bias_init)) deallocate(this%bias_init)
  if(allocated(beta_initialiser))then
     allocate(this%bias_init, source=beta_initialiser)
  else
     this%bias_init = initialiser_setup('zeros')
  end if
  this%bias_init%mean = this%beta_init_mean
  this%bias_init%std = this%beta_init_std

  ! Running-statistics initialisers
  if(allocated(moving_mean_initialiser))then
     this%moving_mean_init = moving_mean_initialiser
  else
     this%moving_mean_init = initialiser_setup('zeros')
  end if
  if(allocated(moving_variance_initialiser))then
     this%moving_variance_init = moving_variance_initialiser
  else
     this%moving_variance_init = initialiser_setup('ones')
  end if

  ! Optional report of the selected initialisers
  if(.not.present(verbose)) return
  if(abs(verbose).le.0) return
  write(*,'("BATCHNORM2D gamma (kernel) initialiser: ",A)') &
       trim(this%kernel_init%name)
  write(*,'("BATCHNORM2D beta (bias) initialiser: ",A)') &
       trim(this%bias_init%name)
  write(*,'("BATCHNORM2D moving mean initialiser: ",A)') &
       trim(this%moving_mean_init%name)
  write(*,'("BATCHNORM2D moving variance initialiser: ",A)') &
       trim(this%moving_variance_init%name)

end subroutine set_hyperparams_batchnorm2d
304 !###############################################################################
305
306
307 !##############################################################################!
308 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
309 !##############################################################################!
310
311
312 !###############################################################################
subroutine read_batchnorm2d(this, unit, verbose)
  !! Read 2D batch normalisation layer from file.
  !!
  !! Parses the layer card line by line until the line
  !! "END <upper-case this%name>". Recognised tags:
  !!   INPUT_SHAPE (required), MOMENTUM, EPSILON, NUM_CHANNELS (ignored),
  !!   GAMMA_INITIALISER / KERNEL_INITIALISER,
  !!   BETA_INITIALISER / BIAS_INITIALISER,
  !!   MOVING_MEAN_INITIALISER, MOVING_VARIANCE_INITIALISER,
  !!   GAMMA and BETA value blocks terminated by "END GAMMA" / "END BETA".
  !! After the hyperparameters are set and the layer initialised, the unit is
  !! moved back to any GAMMA/BETA blocks to load their values into params(1).
  !! On return the unit is positioned after the END line of the card.
  use athena__tools_infile, only: assign_val, assign_vec, move
  use coreutils, only: to_lower, to_upper, icount
  use athena__initialiser, only: initialiser_setup
  implicit none

  ! Arguments
  class(batchnorm2d_layer_type), intent(inout) :: this
  !! Instance of the 2D batch normalisation layer
  integer, intent(in) :: unit
  !! File unit
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: stat, verbose_
  !! Status and verbosity level
  integer :: i, j, k, c, itmp1, iline, final_line
  !! Loop variables, temporary integer and record bookkeeping
  integer :: num_channels
  !! Number of channels (last entry of INPUT_SHAPE)
  real(real32) :: momentum, epsilon
  !! Momentum and epsilon
  logical :: found_input_shape
  !! Whether an INPUT_SHAPE tag was present in the card
  class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
  !! Gamma and beta initialisers
  class(base_init_type), allocatable :: &
       moving_mean_initialiser, moving_variance_initialiser
  !! Moving mean and variance initialisers
  character(32) :: gamma_initialiser_name, beta_initialiser_name
  !! Gamma and beta initialiser names ('' if not given)
  character(32) :: &
       moving_mean_initialiser_name, moving_variance_initialiser_name
  !! Moving mean and variance initialiser names ('' if not given)
  character(256) :: buffer, tag, err_msg
  !! Buffer, tag, and error message
  integer, dimension(3) :: input_shape
  !! Input shape
  real(real32), allocatable, dimension(:) :: data_list
  !! Values collected from a GAMMA or BETA block
  integer, dimension(2) :: param_lines
  !! Card record numbers of the GAMMA (1) and BETA (2) tags; 0 if absent


  ! Initialise locals
  !---------------------------------------------------------------------------
  ! Assigned here rather than on the declarations: a declaration initialiser
  ! implies SAVE, so values parsed from one card (e.g. MOMENTUM, or the
  ! 'zeros' name set by GAMMA) would otherwise leak into the next read.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  momentum = 0._real32
  epsilon = 1.E-5_real32
  found_input_shape = .false.
  gamma_initialiser_name = ''
  beta_initialiser_name = ''
  moving_mean_initialiser_name = ''
  moving_variance_initialiser_name = ''


  ! Loop over tags in layer card
  !---------------------------------------------------------------------------
  ! iline counts every record read from the card (blank ones included, so
  ! that it stays consistent with the record-based repositioning done by
  ! move below); final_line is the number of records before the END line.
  iline = 0
  param_lines = 0
  final_line = 0
  tag_loop: do

     ! Check for end of file
     !------------------------------------------------------------------------
     read(unit,'(A)',iostat=stat) buffer
     if(stat.ne.0)then
        write(err_msg,'("file encountered error (EoF?) before END ",A)') &
             to_upper(this%name)
        call stop_program(err_msg)
        return
     end if

     ! Check for end of layer card
     !------------------------------------------------------------------------
     if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
        final_line = iline
        backspace(unit)
        exit tag_loop
     end if
     iline = iline + 1
     if(trim(adjustl(buffer)).eq."") cycle tag_loop

     tag=trim(adjustl(buffer))
     if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

     ! Read parameters from save file
     !------------------------------------------------------------------------
     select case(trim(tag))
     case("INPUT_SHAPE")
        call assign_vec(buffer, input_shape, itmp1)
        found_input_shape = .true.
     case("MOMENTUM")
        call assign_val(buffer, momentum, itmp1)
     case("EPSILON")
        call assign_val(buffer, epsilon, itmp1)
     case("NUM_CHANNELS")
        call assign_val(buffer, num_channels, itmp1)
        write(0,*) "NUM_CHANNELS and INPUT_SHAPE are conflicting parameters"
        write(0,*) "NUM_CHANNELS will be ignored"
     case("GAMMA_INITIALISER", "KERNEL_INITIALISER")
        if(param_lines(1).ne.0)then
           write(err_msg,'("GAMMA and GAMMA_INITIALISER defined. Using GAMMA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, gamma_initialiser_name, itmp1)
     case("BETA_INITIALISER", "BIAS_INITIALISER")
        if(param_lines(2).ne.0)then
           write(err_msg,'("BETA and BETA_INITIALISER defined. Using BETA only.")')
           call print_warning(err_msg)
        end if
        call assign_val(buffer, beta_initialiser_name, itmp1)
     case("MOVING_MEAN_INITIALISER")
        call assign_val(buffer, moving_mean_initialiser_name, itmp1)
     case("MOVING_VARIANCE_INITIALISER")
        call assign_val(buffer, moving_variance_initialiser_name, itmp1)
     case("GAMMA")
        gamma_initialiser_name = 'zeros'
        param_lines(1) = iline
     case("BETA")
        beta_initialiser_name = 'zeros'
        param_lines(2) = iline
     case default
        ! Don't look for "e" due to scientific notation of numbers
        ! ... i.e. exponent (E+00)
        if(scan(to_lower(trim(adjustl(buffer))),&
             'abcdfghijklmnopqrstuvwxyz').eq.0)then
           cycle tag_loop
        elseif(tag(:3).eq.'END')then
           cycle tag_loop
        end if
        write(err_msg,'("Unrecognised line in input file: ",A)') &
             trim(adjustl(buffer))
        call stop_program(err_msg)
        return
     end select
  end do tag_loop


  ! INPUT_SHAPE is mandatory: it sizes the layer and the GAMMA/BETA buffers
  !---------------------------------------------------------------------------
  if(.not.found_input_shape)then
     write(err_msg,'("INPUT_SHAPE not found in ",A," layer card")') &
          to_upper(trim(this%name))
     call stop_program(err_msg)
     return
  end if


  ! Build initialisers; names left empty stay unallocated so that
  ! set_hyperparams applies its own defaults
  !---------------------------------------------------------------------------
  if(len_trim(gamma_initialiser_name).gt.0) &
       gamma_initialiser = initialiser_setup(gamma_initialiser_name)
  if(len_trim(beta_initialiser_name).gt.0) &
       beta_initialiser = initialiser_setup(beta_initialiser_name)
  if(len_trim(moving_mean_initialiser_name).gt.0) &
       moving_mean_initialiser = &
       initialiser_setup(moving_mean_initialiser_name)
  if(len_trim(moving_variance_initialiser_name).gt.0) &
       moving_variance_initialiser = &
       initialiser_setup(moving_variance_initialiser_name)


  ! Set hyperparameters and initialise layer
  !---------------------------------------------------------------------------
  num_channels = input_shape(size(input_shape,1))
  call this%set_hyperparams( &
       momentum = momentum, &
       epsilon = epsilon, &
       gamma_init_mean = this%gamma_init_mean, &
       gamma_init_std = this%gamma_init_std, &
       beta_init_mean = this%beta_init_mean, &
       beta_init_std = this%beta_init_std, &
       gamma_initialiser = gamma_initialiser, &
       beta_initialiser = beta_initialiser, &
       moving_mean_initialiser = moving_mean_initialiser, &
       moving_variance_initialiser = moving_variance_initialiser, &
       verbose = verbose_ &
  )
  call this%init(input_shape = input_shape)


  ! Load GAMMA / BETA values, if present
  !---------------------------------------------------------------------------
  ! The unit currently sits on the END record (final_line + 1), i.e. the next
  ! record to read is iline + 1. For each block, move to the record after its
  ! tag, collect num_channels values, then expect the matching END tag.
  allocate(data_list(num_channels), source=0._real32)
  do i = 2, 1, -1
     if(param_lines(i).eq.0) cycle
     call move(unit, param_lines(i) - iline, iostat=stat)
     ! The extra +1 offsets the END GAMMA/BETA read below, which does not
     ! bump iline; afterwards the next record to read is again iline + 1.
     iline = param_lines(i) + 1
     c = 1
     data_list = 0._real32
     data_concat_loop: do while(c.le.num_channels)
        iline = iline + 1
        read(unit,'(A)',iostat=stat) buffer
        if(stat.ne.0) exit data_concat_loop
        ! Clamp to the remaining capacity so an over-long line cannot write
        ! past the end of data_list
        k = min(icount(buffer), num_channels - c + 1)
        read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
        c = c + k
     end do data_concat_loop
     read(unit,'(A)',iostat=stat) buffer
     select case(i)
     case(1) ! gamma
        this%params(1)%val(1:this%num_channels,1) = data_list
        if(trim(adjustl(buffer)).ne."END GAMMA")then
           write(err_msg,'("END GAMMA not where expected: ",A)') &
                trim(adjustl(buffer))
           call stop_program(err_msg)
           return
        end if
     case(2) ! beta
        this%params(1)%val(this%num_channels+1:this%num_channels*2,1) = &
             data_list
        if(trim(adjustl(buffer)).ne."END BETA")then
           write(err_msg,'("END BETA not where expected: ",A)') &
                trim(adjustl(buffer))
           call stop_program(err_msg)
           return
        end if
     end select
  end do
  deallocate(data_list)


  ! Check for end of layer card
  !---------------------------------------------------------------------------
  call move(unit, final_line - iline, iostat=stat)
  read(unit,'(A)') buffer
  if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
     write(0,*) trim(adjustl(buffer))
     write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

end subroutine read_batchnorm2d
523 !###############################################################################
524
525
526 !###############################################################################
function read_batchnorm2d_layer(unit, verbose) result(layer)
  !! Read a 2D batch normalisation layer from a file unit.
  !!
  !! Allocates a default-constructed batchnorm2d_layer_type and delegates
  !! the parsing to its read binding.
  implicit none

  ! Arguments
  integer, intent(in) :: unit
  !! File unit to read from
  integer, optional, intent(in) :: verbose
  !! Verbosity level
  class(base_layer_type), allocatable :: layer
  !! Instance of the 2D batch normalisation layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level (0 when verbose is absent)

  ! Assigned at run time rather than on the declaration to avoid the
  ! implicit SAVE attribute (a previous call's verbosity would persist).
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  allocate(layer, source=batchnorm2d_layer_type())
  call layer%read(unit, verbose=verbose_)

end function read_batchnorm2d_layer
540 !###############################################################################
541
542
543 !###############################################################################
subroutine build_from_onnx_batchnorm2d( &
     this, node, initialisers, value_info, verbose &
)
  !! Build the 2D batch normalisation layer from an ONNX node.
  !!
  !! Reads the "epsilon" and "momentum" attributes (defaults 1e-5 and 0.9)
  !! and warns about any other attribute. Exactly four initialisers are
  !! required, taken in the order gamma, beta, running mean,
  !! running variance; each is wrapped in a data_init_type.
  !!
  !! NOTE(review): this procedure is not bound to batchnorm2d_layer_type in
  !! this module, so nothing here calls it -- confirm whether it should be
  !! exposed as a type-bound procedure.
  use athena__initialiser_data, only: data_init_type
  implicit none

  ! Arguments
  class(batchnorm2d_layer_type), intent(inout) :: this
  !! Instance of the 2D batch normalisation layer
  type(onnx_node_type), intent(in) :: node
  !! ONNX node information
  type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
  !! ONNX initialiser information
  type(onnx_tensor_type), dimension(:), intent(in) :: value_info
  !! ONNX value info (unused)
  integer, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: i, stat
  !! Loop index and I/O status of attribute parsing
  real(real32) :: epsilon, momentum
  !! Epsilon and momentum values
  character(256) :: val, err_msg
  !! Attribute value and error message
  class(base_init_type), allocatable :: gamma_initialiser, beta_initialiser
  !! Gamma and beta initialisers
  class(base_init_type), allocatable :: &
       moving_mean_initialiser, moving_variance_initialiser
  !! Running mean and variance initialisers

  ! Set default values
  epsilon = 1.E-5_real32
  momentum = 0.9_real32

  ! Parse ONNX attributes
  do i = 1, size(node%attributes)
     val = node%attributes(i)%val
     select case(trim(adjustl(node%attributes(i)%name)))
     case("epsilon")
        read(val,*,iostat=stat) epsilon
     case("momentum")
        read(val,*,iostat=stat) momentum
     case default
        ! Unsupported attributes are reported, then ignored
        write(0,*) "WARNING: Unrecognised attribute in ONNX BATCHNORM2D &
             &layer: ", trim(adjustl(node%attributes(i)%name))
        cycle
     end select
     ! A malformed numeric attribute would otherwise abort inside the read
     if(stat.ne.0)then
        write(err_msg,'("Invalid value for ONNX BATCHNORM2D attribute ",A,&
             &": ",A)') trim(adjustl(node%attributes(i)%name)), &
             trim(adjustl(val))
        call stop_program(err_msg)
        return
     end if
  end do

  ! Check for 4 initialisers: gamma, beta, mean, variance
  if(size(initialisers).ne.4)then
     call stop_program("ONNX BATCHNORM2D layer requires 4 initialisers &
          &(gamma, beta, mean, variance)")
     return
  end if

  ! ONNX BatchNormalization order: gamma, beta, mean, variance
  gamma_initialiser = data_init_type( data = initialisers(1)%data )
  beta_initialiser = data_init_type( data = initialisers(2)%data )
  moving_mean_initialiser = data_init_type( data = initialisers(3)%data )
  moving_variance_initialiser = data_init_type( data = initialisers(4)%data )

  call this%set_hyperparams( &
       momentum = momentum, &
       epsilon = epsilon, &
       gamma_init_mean = 1.0_real32, &
       gamma_init_std = 0.0_real32, &
       beta_init_mean = 0.0_real32, &
       beta_init_std = 0.0_real32, &
       gamma_initialiser = gamma_initialiser, &
       beta_initialiser = beta_initialiser, &
       moving_mean_initialiser = moving_mean_initialiser, &
       moving_variance_initialiser = moving_variance_initialiser, &
       verbose = verbose &
  )

end subroutine build_from_onnx_batchnorm2d
621 !###############################################################################
622
623
624 !##############################################################################!
625 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
626 !##############################################################################!
627
628
629 !###############################################################################
subroutine forward_batchnorm2d(this, input)
  !! Forward propagation for the 2D batch normalisation layer.
  !!
  !! In inference mode, batchnorm_inference normalises input(1,1) with the
  !! stored running statistics (this%mean, this%variance). Otherwise
  !! batchnorm is called, which additionally receives this%momentum
  !! (presumably to update the running statistics -- confirm in
  !! athena__diffstruc_extd). The result is shallow-assigned into
  !! this%output(1,1) together with its epsilon, mean and variance.
  implicit none

  ! Arguments
  class(batchnorm2d_layer_type), intent(inout) :: this
  !! Instance of the 2D batch normalisation layer
  class(array_type), dimension(:,:), intent(in) :: input
  !! Input values; only input(1,1) is used

  ! Local variables
  class(batchnorm_array_type), pointer :: ptr
  !! Result of the batch normalisation operation


  if(this%inference)then
     ! Inference: normalise with the stored running statistics
     ptr => batchnorm_inference(input(1,1), this%params(1), &
          this%mean(:), this%variance(:), this%epsilon &
     )
  else
     ! Training: momentum and the running statistics are passed through
     ptr => batchnorm( &
          input(1,1), this%params(1), &
          this%momentum, this%mean(:), this%variance(:), this%epsilon &
     )
  end if

  select type(output => this%output(1,1))
  type is(batchnorm_array_type)
     call output%assign_shallow(ptr)
     output%epsilon = ptr%epsilon
     output%mean = ptr%mean
     output%variance = ptr%variance
  class default
     ! Previously the result was silently discarded here, leaving a stale
     ! output; report it instead (after freeing the temporary)
     deallocate(ptr)
     call stop_program("BATCHNORM2D output is not of batchnorm_array_type")
     return
  end select
  deallocate(ptr)
  this%output(1,1)%is_temporary = .false.

end subroutine forward_batchnorm2d
673
674 end module athena__batchnorm2d_layer
675