GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_initialiser_glorot.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 41 51 80.4%
Functions: 0 0 -%
Branches: 155 304 51.0%

Line Branch Exec Source
module athena__initialiser_glorot
  !! Module containing the implementation of the Glorot initialiser
  !!
  !! This module implements Glorot (Xavier) initialisation, designed to
  !! maintain variance of gradients through layers with sigmoid/tanh.
  !!
  !! Mathematical operation:
  !!
  !! Uniform variant:
  !! \[ W \sim \mathcal{U}(-\text{limit}, \text{limit}), \quad \text{limit} = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! Normal variant:
  !! \[ W \sim \mathcal{N}(0, \sigma^2), \quad \sigma = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! where \(n_{\text{in}}\) is fan-in, \(n_{\text{out}}\) is fan-out.
  !!
  !! Rationale: Maintains variance across layers, prevents vanishing/exploding
  !! gradients in deep networks
  !!
  !! Best for: Tanh, Sigmoid, Softmax activations
  !! Reference: Glorot & Bengio (2010), AISTATS
  use coreutils, only: real32, pi, stop_program
  use athena__misc_types, only: base_init_type
  implicit none


  private

  public :: glorot_uniform_init_type
  public :: glorot_normal_init_type


  type, extends(base_init_type) :: glorot_uniform_init_type
     !! Type for the Glorot initialiser (uniform)
   contains
     procedure, pass(this) :: initialise => glorot_uniform_initialise
     !! Initialise the weights and biases using the Glorot uniform distribution
  end type glorot_uniform_init_type

  type, extends(base_init_type) :: glorot_normal_init_type
     !! Type for the Glorot initialiser (normal)
   contains
     procedure, pass(this) :: initialise => glorot_normal_initialise
     !! Initialise the weights and biases using the Glorot normal distribution
  end type glorot_normal_init_type


  interface glorot_uniform_init_type
     module function initialiser_uniform_setup() result(initialiser)
       !! Interface for the Glorot uniform initialiser
       type(glorot_uniform_init_type) :: initialiser
       !! Glorot uniform initialiser object
     end function initialiser_uniform_setup
  end interface glorot_uniform_init_type

  interface glorot_normal_init_type
     module function initialiser_normal_setup() result(initialiser)
       !! Interface for the Glorot normal initialiser
       type(glorot_normal_init_type) :: initialiser
       !! Glorot normal initialiser object
     end function initialiser_normal_setup
  end interface glorot_normal_init_type



contains

!###############################################################################
  module function initialiser_uniform_setup() result(initialiser)
    !! Construct a Glorot uniform initialiser named "glorot_uniform"
    implicit none

    ! Arguments
    type(glorot_uniform_init_type) :: initialiser
    !! Glorot uniform initialiser object

    initialiser%name = "glorot_uniform"

  end function initialiser_uniform_setup
!-------------------------------------------------------------------------------
  module function initialiser_normal_setup() result(initialiser)
    !! Construct a Glorot normal initialiser named "glorot_normal"
    implicit none

    ! Arguments
    type(glorot_normal_init_type) :: initialiser
    !! Glorot normal initialiser object

    initialiser%name = "glorot_normal"

  end function initialiser_normal_setup
!###############################################################################


!###############################################################################
  subroutine check_fans(caller, fan_sum, valid, fan_in, fan_out)
    !! Validate the fan arguments shared by both Glorot variants
    !!
    !! Reports an error through stop_program when fan_in or fan_out is
    !! absent, or when fan_in + fan_out is not positive (which would
    !! otherwise yield an infinite or NaN scale). valid is set to .true.
    !! only when every check passes.
    implicit none

    ! Arguments
    character(*), intent(in) :: caller
    !! Name of the calling procedure, used as the error-message prefix
    real(real32), intent(out) :: fan_sum
    !! fan_in + fan_out as a real32 value (0 if validation stops early)
    logical, intent(out) :: valid
    !! Whether the fan arguments are usable
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units

    valid = .false.
    fan_sum = 0._real32

    ! Return right after each report so an absent optional is never read,
    ! even if stop_program does not terminate execution
    if(.not.present(fan_in))then
       call stop_program(caller // ": fan_in not present")
       return
    end if
    if(.not.present(fan_out))then
       call stop_program(caller // ": fan_out not present")
       return
    end if

    ! Sum in real arithmetic so very large fans cannot overflow an integer
    fan_sum = real(fan_in, real32) + real(fan_out, real32)
    if(fan_sum .le. 0._real32)then
       call stop_program(caller // ": fan_in + fan_out must be positive")
       return
    end if

    valid = .true.

  end subroutine check_fans
!-------------------------------------------------------------------------------
  subroutine assign_by_rank(input, values, caller)
    !! Copy a flat array of values into an assumed-rank array (ranks 0 to 6)
    !!
    !! Elements are assigned in array-element (column-major) order via
    !! reshape. For a scalar target only values(1) is used. Any other rank
    !! is reported through stop_program instead of leaving input undefined.
    implicit none

    ! Arguments
    real(real32), dimension(..), intent(out) :: input
    !! Array to fill; size(values) must equal size(input)
    real(real32), dimension(:), intent(in) :: values
    !! Flat source values
    character(*), intent(in) :: caller
    !! Name of the calling procedure, used as the error-message prefix

    select rank(input)
    rank(0)
       input = values(1)
    rank(1)
       input = values
    rank(2)
       input = reshape(values, shape(input))
    rank(3)
       input = reshape(values, shape(input))
    rank(4)
       input = reshape(values, shape(input))
    rank(5)
       input = reshape(values, shape(input))
    rank(6)
       input = reshape(values, shape(input))
    rank default
       call stop_program(caller // ": unsupported rank of input")
    end select

  end subroutine assign_by_rank
!###############################################################################


!###############################################################################
  subroutine glorot_uniform_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot uniform distribution
    !!
    !! Every element of input is drawn from U(-limit, limit), where
    !! limit = sqrt(6 / (fan_in + fan_out)).
    implicit none

    ! Arguments
    class(glorot_uniform_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units (both required)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    real(real32) :: limit
    !! Half-width of the uniform interval
    real(real32) :: fan_sum
    !! fan_in + fan_out
    logical :: valid
    !! Whether the fan arguments passed validation
    real(real32), dimension(:), allocatable :: r
    !! Temporary uniform random numbers (deallocated automatically on return)

    ! Validate inputs
    call check_fans("glorot_uniform_initialise", fan_sum, valid, fan_in, fan_out)
    if(.not.valid) return

    limit = sqrt(6._real32 / fan_sum)

    n = size(input)
    allocate(r(n))
    call random_number(r)

    ! Map [0, 1) onto [-limit, limit)
    r = (2._real32 * r - 1._real32) * limit

    ! Assign according to rank
    call assign_by_rank(input, r, "glorot_uniform_initialise")

  end subroutine glorot_uniform_initialise
!###############################################################################


!###############################################################################
  subroutine glorot_normal_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot normal distribution
    !!
    !! Every element of input is drawn from N(0, sigma^2), where
    !! sigma = sqrt(2 / (fan_in + fan_out)), using the Box-Muller transform.
    implicit none

    ! Arguments
    class(glorot_normal_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units (both required)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used here, included for compatibility)

    ! Local variables
    integer :: n
    !! Number of elements in the input array
    real(real32) :: sigma
    !! Standard deviation of the normal distribution
    real(real32) :: fan_sum
    !! fan_in + fan_out
    logical :: valid
    !! Whether the fan arguments passed validation
    real(real32), dimension(:), allocatable :: u1, u2, z
    !! Temporary arrays for the random numbers (deallocated automatically)

    ! Validate inputs (reject absent or non-positive fans to avoid division by zero)
    call check_fans("glorot_normal_initialise", fan_sum, valid, fan_in, fan_out)
    if(.not.valid) return

    sigma = sqrt(2._real32 / fan_sum)

    n = size(input)
    allocate(u1(n), u2(n), z(n))

    call random_number(u1)
    call random_number(u2)

    ! random_number returns values in [0, 1); keep u1 away from zero so
    ! that log(u1) below stays finite
    where (u1 .lt. 1.E-7_real32)
       u1 = 1.E-7_real32
    end where

    ! Box-Muller transform for a standard normal sample, then scale by sigma
    z = sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2)
    z = sigma * z

    ! Assign according to rank
    call assign_by_rank(input, z, "glorot_normal_initialise")

  end subroutine glorot_normal_initialise
!###############################################################################

end module athena__initialiser_glorot
216