GCC Code Coverage Report


Directory: src/athena/
File: athena_initialiser_glorot.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
module athena__initialiser_glorot
  !! Module containing the implementation of the Glorot initialiser
  !!
  !! This module implements Glorot (Xavier) initialisation, designed to
  !! maintain the variance of activations and gradients through layers that
  !! use sigmoid/tanh-like activations.
  !!
  !! Mathematical operation:
  !!
  !! Uniform variant:
  !! \[ W \sim \mathcal{U}(-\text{limit}, \text{limit}), \quad \text{limit} = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! Normal variant:
  !! \[ W \sim \mathcal{N}(0, \sigma^2), \quad \sigma = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} \]
  !!
  !! where \(n_{\text{in}}\) is fan-in, \(n_{\text{out}}\) is fan-out.
  !! The normal variant is sampled (untruncated) with the Box-Muller transform.
  !!
  !! Rationale: Maintains variance across layers, prevents vanishing/exploding
  !! gradients in deep networks
  !!
  !! Best for: Tanh, Sigmoid, Softmax activations
  !! Reference: Glorot & Bengio (2010), AISTATS
  use coreutils, only: real32, pi, stop_program
  use athena__misc_types, only: base_init_type
  implicit none


  private

  public :: glorot_uniform_init_type
  public :: glorot_normal_init_type


  type, extends(base_init_type) :: glorot_uniform_init_type
     !! Type for the Glorot initialiser (uniform)
   contains
     procedure, pass(this) :: initialise => glorot_uniform_initialise
     !! Initialise the weights and biases using the Glorot uniform distribution
  end type glorot_uniform_init_type

  type, extends(base_init_type) :: glorot_normal_init_type
     !! Type for the Glorot initialiser (normal)
   contains
     procedure, pass(this) :: initialise => glorot_normal_initialise
     !! Initialise the weights and biases using the Glorot normal distribution
  end type glorot_normal_init_type


  interface glorot_uniform_init_type
     module function initialiser_uniform_setup() result(initialiser)
       !! Interface for the Glorot uniform initialiser
       type(glorot_uniform_init_type) :: initialiser
       !! Glorot uniform initialiser object
     end function initialiser_uniform_setup
  end interface glorot_uniform_init_type

  interface glorot_normal_init_type
     module function initialiser_normal_setup() result(initialiser)
       !! Interface for the Glorot normal initialiser
       type(glorot_normal_init_type) :: initialiser
       !! Glorot normal initialiser object
     end function initialiser_normal_setup
  end interface glorot_normal_init_type



contains

!###############################################################################
  module function initialiser_uniform_setup() result(initialiser)
    !! Construct a Glorot uniform initialiser, setting its name
    implicit none

    ! Arguments
    type(glorot_uniform_init_type) :: initialiser
    !! Glorot uniform initialiser object

    initialiser%name = "glorot_uniform"

  end function initialiser_uniform_setup
!-------------------------------------------------------------------------------
  module function initialiser_normal_setup() result(initialiser)
    !! Construct a Glorot normal initialiser, setting its name
    implicit none

    ! Arguments
    type(glorot_normal_init_type) :: initialiser
    !! Glorot normal initialiser object

    initialiser%name = "glorot_normal"

  end function initialiser_normal_setup
!###############################################################################


!###############################################################################
  subroutine check_fans(caller, fan_in, fan_out, ok)
    !! Validate the fan arguments shared by both Glorot variants
    !!
    !! Sets ok to .true. only when fan_in and fan_out are both present and
    !! their sum is positive. Otherwise the problem is reported through
    !! stop_program and ok is set to .false.
    implicit none

    ! Arguments
    character(*), intent(in) :: caller
    !! Name of the calling procedure, used as the error-message prefix
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units
    logical, intent(out) :: ok
    !! Whether the fan arguments can be used

    ! ok lets the caller bail out instead of touching absent arguments should
    ! stop_program hand control back.
    ! NOTE(review): assumed possible, e.g. when error handling is under test;
    ! confirm in coreutils.
    ok = .false.
    if(.not.present(fan_in))then
       call stop_program(caller // ": fan_in not present")
    else if(.not.present(fan_out))then
       call stop_program(caller // ": fan_out not present")
    else if(fan_in + fan_out .le. 0)then
       ! A non-positive sum would divide by zero / take sqrt of a negative.
       call stop_program(caller // ": fan_in + fan_out must be positive")
    else
       ok = .true.
    end if

  end subroutine check_fans
!-------------------------------------------------------------------------------
  subroutine assign_flat_to_ranked(input, values, caller)
    !! Copy a flat array into an assumed-rank array in array element order
    !!
    !! values must hold exactly size(input) elements. Ranks 0 to 6 are
    !! supported; any other rank is reported through stop_program.
    implicit none

    ! Arguments
    real(real32), dimension(..), intent(out) :: input
    !! Destination array (any supported rank)
    real(real32), dimension(:), intent(in) :: values
    !! Flat source values, size(values) == size(input)
    character(*), intent(in) :: caller
    !! Name of the calling procedure, used as the error-message prefix

    select rank(input)
    rank(0)
       input = values(1)
    rank(1)
       input = values
    rank(2)
       input = reshape(values, shape(input))
    rank(3)
       input = reshape(values, shape(input))
    rank(4)
       input = reshape(values, shape(input))
    rank(5)
       input = reshape(values, shape(input))
    rank(6)
       input = reshape(values, shape(input))
    rank default
       ! Previously higher ranks were silently left undefined.
       call stop_program(caller // ": unsupported input rank")
    end select

  end subroutine assign_flat_to_ranked
!###############################################################################


!###############################################################################
  subroutine glorot_uniform_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot uniform distribution
    !!
    !! Every element of input is drawn independently from U(-limit, limit),
    !! where limit = sqrt(6 / (fan_in + fan_out)).
    implicit none

    ! Arguments
    class(glorot_uniform_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units (both required)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used)

    ! Local variables
    logical :: fans_ok
    !! Whether fan_in/fan_out passed validation
    real(real32) :: limit
    !! Half-width of the uniform distribution
    real(real32), dimension(:), allocatable :: r
    !! Flat buffer of random samples

    call check_fans("glorot_uniform_initialise", fan_in, fan_out, fans_ok)
    if(.not.fans_ok) return

    limit = sqrt(6._real32 / real(fan_in + fan_out, real32))
    allocate(r(size(input)))
    call random_number(r)
    ! Map random_number's [0, 1) onto [-limit, limit)
    r = (2._real32 * r - 1._real32) * limit

    call assign_flat_to_ranked(input, r, "glorot_uniform_initialise")

  end subroutine glorot_uniform_initialise
!###############################################################################


!###############################################################################
  subroutine glorot_normal_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the Glorot normal distribution
    !!
    !! Every element of input is drawn independently from N(0, sigma^2),
    !! where sigma = sqrt(2 / (fan_in + fan_out)). Samples are generated with
    !! the cosine branch of the Box-Muller transform and are not truncated.
    implicit none

    ! Arguments
    class(glorot_normal_init_type), intent(inout) :: this
    !! Instance of the Glorot initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output units (both required)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units (not used here, included for compatibility)

    ! Local variables
    real(real32), parameter :: min_uniform = 1.E-7_real32
    !! Lower clamp on u1 so that log(u1) stays finite
    logical :: fans_ok
    !! Whether fan_in/fan_out passed validation
    real(real32) :: sigma
    !! Standard deviation of the normal distribution
    real(real32), dimension(:), allocatable :: u1, u2, z
    !! Uniform inputs and normal outputs of the Box-Muller transform

    call check_fans("glorot_normal_initialise", fan_in, fan_out, fans_ok)
    if(.not.fans_ok) return

    sigma = sqrt(2._real32 / real(fan_in + fan_out, real32))
    allocate(u1(size(input)), u2(size(input)), z(size(input)))

    call random_number(u1)
    call random_number(u2)
    ! random_number may return exactly 0, for which log would be -Inf
    u1 = max(u1, min_uniform)

    ! Box-Muller transform: standard normal samples, then scale by sigma
    z = sigma * (sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2))

    call assign_flat_to_ranked(input, z, "glorot_normal_initialise")

  end subroutine glorot_normal_initialise
!###############################################################################

end module athena__initialiser_glorot
216