GCC Code Coverage Report


Directory: src/athena/
File: athena_initialiser_he.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__initialiser_he
2 !! Module containing the implementation of the He initialiser
3 !!
4 !! This module implements He (Kaiming/MSRA) initialisation, designed for
5 !! layers with ReLU activation to prevent vanishing/exploding gradients.
6 !!
7 !! Mathematical operation:
8 !!
9 !! Uniform variant:
10 !! \[ W \sim \mathcal{U}(-\text{limit}, \text{limit}), \quad \text{limit} = \sqrt{\frac{6}{n_{\text{in}}}} \]
11 !!
12 !! Normal variant:
13 !! \[ W \sim \mathcal{N}(0, \sigma^2), \quad \sigma = \sqrt{\frac{2}{n_{\text{in}}}} \]
14 !!
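15 !! where \(n_{\text{in}}\) is the number of input units (fan-in). With mode "fan_out" the number of output units is used instead, and both the limit and \(\sigma\) are multiplied by the optional scale factor.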
16 !!
17 !! Rationale: Maintains variance through ReLU layers
18 !! \(\text{Var}(\text{output}) \approx \text{Var}(\text{input})\)
19 !!
20 !! Best for: ReLU, Leaky ReLU, PReLU activations
21 !! Reference: He et al. (2015), ICCV, arXiv:1502.01852
22 use coreutils, only: real32, pi, to_lower, stop_program
23 use athena__misc_types, only: base_init_type
24 implicit none
25
26
27 private
28
29 public :: he_uniform_init_type, he_normal_init_type
30
31
type, extends(base_init_type) :: he_uniform_init_type
   !! He (Kaiming) initialiser that samples from a uniform distribution
   integer, private :: mode = 1
   !! Selects the fan used for the limit: 1 = fan_in (default), 2 = fan_out
 contains
   procedure :: initialise => he_uniform_initialise
   !! Fill an array with He-uniform distributed values
end type he_uniform_init_type
39
type, extends(base_init_type) :: he_normal_init_type
   !! He (Kaiming) initialiser that samples from a normal distribution
   integer, private :: mode = 1
   !! Selects the fan used for sigma: 1 = fan_in (default), 2 = fan_out
 contains
   procedure :: initialise => he_normal_initialise
   !! Fill an array with He-normal distributed values
end type he_normal_init_type
47
48
interface he_uniform_init_type
   !! Generic constructor for the He uniform initialiser
   module function initialiser_uniform_setup(scale, mode) result(initialiser)
      !! Build a He uniform initialiser from optional settings
      type(he_uniform_init_type) :: initialiser
      !! Configured He uniform initialiser
      real(real32), optional, intent(in) :: scale
      !! Multiplier applied to the sampling limit (default: 1.0)
      character(*), optional, intent(in) :: mode
      !! Fan selection, "fan_in" or "fan_out" (default: "fan_in")
   end function initialiser_uniform_setup
end interface he_uniform_init_type
60
interface he_normal_init_type
   !! Generic constructor for the He normal initialiser
   module function initialiser_normal_setup(scale, mode) result(initialiser)
      !! Build a He normal initialiser from optional settings
      type(he_normal_init_type) :: initialiser
      !! Configured He normal initialiser
      real(real32), optional, intent(in) :: scale
      !! Multiplier applied to the standard deviation (default: 1.0)
      character(*), optional, intent(in) :: mode
      !! Fan selection, "fan_in" or "fan_out" (default: "fan_in")
   end function initialiser_normal_setup
end interface he_normal_init_type
72
73
74
75 contains
76
77 !###############################################################################
module function initialiser_uniform_setup(scale, mode) result(initialiser)
   !! Construct a He uniform initialiser
   !!
   !! Sets the initialiser name to "he_uniform" and applies the optional
   !! scale and fan mode. Arguments that are absent leave the corresponding
   !! component at its declared default.
   implicit none

   ! Arguments
   real(real32), intent(in), optional :: scale
   !! Scaling factor (default: 1.0)
   character(len=*), intent(in), optional :: mode
   !! Mode for calculating the scaling factor (default: "fan_in")
   type(he_uniform_init_type) :: initialiser
   !! He uniform initialiser object

   initialiser%name = "he_uniform"
   if(present(scale)) initialiser%scale = scale

   if(present(mode))then
      ! Select on the complete lower-cased string. Copying it into a
      ! fixed-length buffer first would truncate long input and could let a
      ! malformed mode (valid prefix, blanks, trailing junk) slip through.
      select case(to_lower(trim(mode)))
      case("fan_in")
         initialiser%mode = 1
      case("fan_out")
         initialiser%mode = 2
      case default
         call stop_program("initialiser_setup: invalid mode")
      end select
   end if

end function initialiser_uniform_setup
107 !-------------------------------------------------------------------------------
module function initialiser_normal_setup(scale, mode) result(initialiser)
   !! Construct a He normal initialiser
   !!
   !! Sets the initialiser name to "he_normal" and applies the optional
   !! scale and fan mode. Arguments that are absent leave the corresponding
   !! component at its declared default.
   implicit none

   ! Arguments
   real(real32), intent(in), optional :: scale
   !! Scaling factor (default: 1.0)
   character(len=*), intent(in), optional :: mode
   !! Mode for calculating the scaling factor (default: "fan_in")
   type(he_normal_init_type) :: initialiser
   !! He normal initialiser object

   initialiser%name = "he_normal"
   if(present(scale)) initialiser%scale = scale

   if(present(mode))then
      ! Select on the complete lower-cased string. Copying it into a
      ! fixed-length buffer first would truncate long input and could let a
      ! malformed mode (valid prefix, blanks, trailing junk) slip through.
      select case(to_lower(trim(mode)))
      case("fan_in")
         initialiser%mode = 1
      case("fan_out")
         initialiser%mode = 2
      case default
         call stop_program("initialiser_setup: invalid mode")
      end select
   end if

end function initialiser_normal_setup
136 !###############################################################################
137
138
139 !###############################################################################
subroutine he_uniform_initialise(this, input, fan_in, fan_out, spacing)
   !! Initialise the weights and biases using the He uniform distribution
   !!
   !! Every element of `input` is drawn from U(-limit, limit) with
   !! limit = scale * sqrt(6 / fan), where fan is `fan_in` when the
   !! initialiser mode is 1 and `fan_out` when it is 2.
   implicit none

   ! Arguments
   class(he_uniform_init_type), intent(inout) :: this
   !! Instance of the He uniform initialiser
   real(real32), dimension(..), intent(out) :: input
   !! Weights and biases to initialise (scalar or array of rank 1 to 6)
   integer, optional, intent(in) :: fan_in, fan_out
   !! Number of input and output units; the one selected by the mode
   !! must be present
   integer, dimension(:), optional, intent(in) :: spacing
   !! Spacing of the input and output units (not used)

   ! Local variables
   integer :: n
   !! Number of elements in the input array
   integer :: fan
   !! Fan count selected by the initialiser mode
   real(real32) :: limit
   !! Half-width of the uniform distribution
   real(real32), dimension(:), allocatable :: r
   !! Temporary uniform random numbers

   ! Only reference the optional fan that the mode actually needs, and only
   ! after confirming it is present. The explicit returns guard against
   ! stop_program handing control back (presumably possible in a test
   ! configuration - confirm against coreutils).
   select case(this%mode)
   case(1)
      if(.not.present(fan_in))then
         call stop_program("he_uniform_initialise: fan_in not present")
         return
      end if
      fan = fan_in
   case(2)
      if(.not.present(fan_out))then
         call stop_program("he_uniform_initialise: fan_out not present")
         return
      end if
      fan = fan_out
   case default
      call stop_program("he_uniform_initialise: invalid mode")
      return
   end select
   limit = this%scale * sqrt(6._real32 / real(fan, real32))

   ! Draw n samples from U(0,1) and map them onto U(-limit, limit)
   n = size(input)
   allocate(r(n))
   call random_number(r)
   r = (2._real32 * r - 1._real32) * limit

   ! Assign according to rank
   select rank(input)
   rank(0)
      input = r(1)
   rank(1)
      input = r
   rank(2)
      input = reshape(r, shape(input))
   rank(3)
      input = reshape(r, shape(input))
   rank(4)
      input = reshape(r, shape(input))
   rank(5)
      input = reshape(r, shape(input))
   rank(6)
      input = reshape(r, shape(input))
   rank default
      ! Without this branch the intent(out) array would be returned
      ! uninitialised with no diagnostic
      call stop_program("he_uniform_initialise: unsupported rank")
   end select

   ! r is deallocated automatically on exit
end subroutine he_uniform_initialise
198 !###############################################################################
199
200
201 !###############################################################################
subroutine he_normal_initialise(this, input, fan_in, fan_out, spacing)
   !! Initialise the weights and biases using the He normal distribution
   !!
   !! Every element of `input` is drawn from N(0, sigma^2) with
   !! sigma = scale * sqrt(2 / fan), where fan is `fan_in` when the
   !! initialiser mode is 1 and `fan_out` when it is 2. Normal deviates are
   !! produced with the cosine branch of the Box-Muller transform.
   implicit none

   ! Arguments
   class(he_normal_init_type), intent(inout) :: this
   !! Instance of the He normal initialiser
   real(real32), dimension(..), intent(out) :: input
   !! Weights and biases to initialise (scalar or array of rank 1 to 6)
   integer, optional, intent(in) :: fan_in, fan_out
   !! Number of input and output units; the one selected by the mode
   !! must be present
   integer, dimension(:), optional, intent(in) :: spacing
   !! Spacing of the input and output units (not used)

   ! Local variables
   real(real32), parameter :: u_min = 1.E-7_real32
   !! Lower clamp on the uniform samples so that log(u1) stays finite
   integer :: n
   !! Number of elements in the input array
   integer :: fan
   !! Fan count selected by the initialiser mode
   real(real32) :: sigma
   !! Standard deviation of the normal distribution
   real(real32), dimension(:), allocatable :: u1, u2, z
   !! Temporary arrays for the random numbers

   ! Only reference the optional fan that the mode actually needs, and only
   ! after confirming it is present. The explicit returns guard against
   ! stop_program handing control back (presumably possible in a test
   ! configuration - confirm against coreutils).
   select case(this%mode)
   case(1)
      if(.not.present(fan_in))then
         call stop_program("he_normal_initialise: fan_in not present")
         return
      end if
      fan = fan_in
   case(2)
      if(.not.present(fan_out))then
         call stop_program("he_normal_initialise: fan_out not present")
         return
      end if
      fan = fan_out
   case default
      call stop_program("he_normal_initialise: invalid mode")
      return
   end select
   sigma = this%scale * sqrt(2._real32 / real(fan, real32))

   n = size(input)
   allocate(u1(n), u2(n), z(n))

   call random_number(u1)
   call random_number(u2)
   ! random_number may return exactly zero; clamp to keep the log finite
   u1 = max(u1, u_min)

   ! Box-Muller transform
   z = sqrt(-2._real32 * log(u1)) * cos(2._real32 * pi * u2)
   z = sigma * z

   ! Assign according to rank
   select rank(input)
   rank(0)
      input = z(1)
   rank(1)
      input = z
   rank(2)
      input = reshape(z, shape(input))
   rank(3)
      input = reshape(z, shape(input))
   rank(4)
      input = reshape(z, shape(input))
   rank(5)
      input = reshape(z, shape(input))
   rank(6)
      input = reshape(z, shape(input))
   rank default
      ! Without this branch the intent(out) array would be returned
      ! uninitialised with no diagnostic
      call stop_program("he_normal_initialise: unsupported rank")
   end select

   ! u1, u2 and z are deallocated automatically on exit
end subroutine he_normal_initialise
268 !###############################################################################
269
270 end module athena__initialiser_he
271