GCC Code Coverage Report


Directory: src/lib/
File: src/lib/mod_optimiser.f90
Date: 2024-06-28 12:51:18
            Exec   Total   Coverage
Lines:       160     179      89.4%
Functions:     0       0         -%
Branches:    454     943      48.1%

Line Branch Exec Source
1 !!!#############################################################################
2 !!! Code written by Ned Thaddeus Taylor
3 !!! Code part of the ATHENA library - a feedforward neural network library
4 !!!#############################################################################
5 !!! module contains implementations of optimisation methods
6 !!! module contains the following derived types:
7 !!! - base_optimiser_type - base optimiser type
8 !!! - sgd_optimiser_type - stochastic gradient descent optimiser type
9 !!! - rmsprop_optimiser_type - rmsprop optimiser type
10 !!! - adagrad_optimiser_type - adagrad optimiser type
11 !!! - adam_optimiser_type - adam optimiser type
12 !!!##################
13 !!! <NAME>_optimiser_type contains the following procedures:
14 !!! - init_gradients - initialise gradients
15 !!! - minimise - minimise the loss function by applying gradients to ...
16 !!! ... the parameters
17 !!!#############################################################################
18 !!! Attribution statement:
19 !!! The following module is based on code from the neural-fortran library
20 !!! https://github.com/modern-fortran/neural-fortran/blob/main/src/nf/nf_optimizers.f90
21 !!! The implementation of optimiser_base_type is based on the ...
22 !!! ... optimizer_base_type from the neural-fortran library
23 !!! The same applies to the implementation of the sgd_optimiser_type, ...
24 !!! ... rmsprop_optimiser_type, adagrad_optimiser_type, and adam_optimiser_type
25 !!!#############################################################################
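The module header above lists the optimiser derived types and their two type-bound procedures (init_gradients and minimise). As a rough usage sketch only, not part of the covered source, and assuming the companion constants and regulariser modules export real12 and a default-constructible l2_regulariser_type, the structure constructors declared below might be driven like this:

    program optimiser_usage_sketch
      use constants, only: real12
      use optimiser, only: sgd_optimiser_type, adam_optimiser_type
      use regulariser, only: l2_regulariser_type
      implicit none
      type(sgd_optimiser_type)  :: sgd
      type(adam_optimiser_type) :: adam
      real(real12), dimension(10) :: params, grads

      params = 0._real12
      grads  = 0.1_real12

      !! constructors set hyperparameters and allocate the per-parameter buffers
      sgd  = sgd_optimiser_type(learning_rate=0.01_real12, momentum=0.9_real12, &
           num_params=size(params))
      adam = adam_optimiser_type(learning_rate=0.001_real12, num_params=size(params), &
           regulariser=l2_regulariser_type())

      !! one step: apply the (optionally regularised, decayed) gradients to the parameters
      call sgd%minimise(params, grads)
      call adam%minimise(params, grads)
    end program optimiser_usage_sketch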
26 module optimiser
27 use constants, only: real12
28 use clipper, only: clip_type
29 use regulariser, only: &
30 base_regulariser_type, &
31 l2_regulariser_type
32 use learning_rate_decay, only: base_lr_decay_type
33 implicit none
34
35 !!!-----------------------------------------------------------------------------
36 !!! learning parameter type
37 !!!-----------------------------------------------------------------------------
38 type :: base_optimiser_type !!base_optimiser_type
39 !! iter = iteration number
40 !! learning_rate = learning rate hyperparameter
41 !! regularisation = apply regularisation
42 !! regulariser = regularisation method
43 !! clip_dict = clipping dictionary
44 integer :: iter = 0
45 real(real12) :: learning_rate = 0.01_real12
46 logical :: regularisation = .false.
47 class(base_regulariser_type), allocatable :: regulariser
48 class(base_lr_decay_type), allocatable :: lr_decay
49 type(clip_type) :: clip_dict
50 contains
51 procedure, pass(this) :: init => init_base
52 procedure, pass(this) :: init_gradients => init_gradients_base
53 procedure, pass(this) :: minimise => minimise_base
54 end type base_optimiser_type
55
56 interface base_optimiser_type
57 module function optimiser_setup_base( &
58 learning_rate, &
59 num_params, &
60 regulariser, clip_dict, lr_decay) result(optimiser)
61 real(real12), optional, intent(in) :: learning_rate
62 integer, optional, intent(in) :: num_params
63 class(base_regulariser_type), optional, intent(in) :: regulariser
64 type(clip_type), optional, intent(in) :: clip_dict
65 class(base_lr_decay_type), optional, intent(in) :: lr_decay
66 type(base_optimiser_type) :: optimiser
67 end function optimiser_setup_base
68 end interface base_optimiser_type
69
70
71 !!!-----------------------------------------------------------------------------
72
73 type, extends(base_optimiser_type) :: sgd_optimiser_type
74 logical :: nesterov = .false.
75 real(real12) :: momentum = 0._real12 ! fraction of momentum based learning
76 real(real12), allocatable, dimension(:) :: velocity
77 contains
78 procedure, pass(this) :: init_gradients => init_gradients_sgd
79 procedure, pass(this) :: minimise => minimise_sgd
80 end type sgd_optimiser_type
81
82 interface sgd_optimiser_type
83 module function optimiser_setup_sgd( &
84 learning_rate, momentum, &
85 nesterov, num_params, &
86 regulariser, clip_dict, lr_decay) result(optimiser)
87 real(real12), optional, intent(in) :: learning_rate, momentum
88 logical, optional, intent(in) :: nesterov
89 integer, optional, intent(in) :: num_params
90 class(base_regulariser_type), optional, intent(in) :: regulariser
91 type(clip_type), optional, intent(in) :: clip_dict
92 class(base_lr_decay_type), optional, intent(in) :: lr_decay
93 type(sgd_optimiser_type) :: optimiser
94 end function optimiser_setup_sgd
95 end interface sgd_optimiser_type
96
97 !!!-----------------------------------------------------------------------------
98
99 type, extends(base_optimiser_type) :: rmsprop_optimiser_type
100 real(real12) :: beta = 0._real12
101 real(real12) :: epsilon = 1.E-8_real12
102 real(real12), allocatable, dimension(:) :: moving_avg
103 contains
104 procedure, pass(this) :: init_gradients => init_gradients_rmsprop
105 procedure, pass(this) :: minimise => minimise_rmsprop
106 end type rmsprop_optimiser_type
107
108 interface rmsprop_optimiser_type
109 module function optimiser_setup_rmsprop( &
110 learning_rate, beta, &
111 epsilon, num_params, &
112 regulariser, clip_dict, lr_decay) result(optimiser)
113 real(real12), optional, intent(in) :: learning_rate, beta, epsilon
114 integer, optional, intent(in) :: num_params
115 class(base_regulariser_type), optional, intent(in) :: regulariser
116 type(clip_type), optional, intent(in) :: clip_dict
117 class(base_lr_decay_type), optional, intent(in) :: lr_decay
118 type(rmsprop_optimiser_type) :: optimiser
119 end function optimiser_setup_rmsprop
120 end interface rmsprop_optimiser_type
121
122 !!!-----------------------------------------------------------------------------
123
124 type, extends(base_optimiser_type) :: adagrad_optimiser_type
125 real(real12) :: epsilon = 1.E-8_real12
126 real(real12), allocatable, dimension(:) :: sum_squares
127 contains
128 procedure, pass(this) :: init_gradients => init_gradients_adagrad
129 procedure, pass(this) :: minimise => minimise_adagrad
130 end type adagrad_optimiser_type
131
132 interface adagrad_optimiser_type
133 module function optimiser_setup_adagrad( &
134 learning_rate, &
135 epsilon, num_params, &
136 regulariser, clip_dict, lr_decay) result(optimiser)
137 real(real12), optional, intent(in) :: learning_rate, epsilon
138 integer, optional, intent(in) :: num_params
139 class(base_regulariser_type), optional, intent(in) :: regulariser
140 type(clip_type), optional, intent(in) :: clip_dict
141 class(base_lr_decay_type), optional, intent(in) :: lr_decay
142 type(adagrad_optimiser_type) :: optimiser
143 end function optimiser_setup_adagrad
144 end interface adagrad_optimiser_type
145
146 !!!-----------------------------------------------------------------------------
147
148 type, extends(base_optimiser_type) :: adam_optimiser_type
149 real(real12) :: beta1 = 0.9_real12
150 real(real12) :: beta2 = 0.999_real12
151 real(real12) :: epsilon = 1.E-8_real12
152 real(real12), allocatable, dimension(:) :: m
153 real(real12), allocatable, dimension(:) :: v
154 contains
155 procedure, pass(this) :: init_gradients => init_gradients_adam
156 procedure, pass(this) :: minimise => minimise_adam
157 end type adam_optimiser_type
158
159 interface adam_optimiser_type
160 module function optimiser_setup_adam( &
161 learning_rate, &
162 beta1, beta2, epsilon, &
163 num_params, &
164 regulariser, clip_dict, lr_decay) result(optimiser)
165 real(real12), optional, intent(in) :: learning_rate
166 real(real12), optional, intent(in) :: beta1, beta2, epsilon
167 integer, optional, intent(in) :: num_params
168 class(base_regulariser_type), optional, intent(in) :: regulariser
169 type(clip_type), optional, intent(in) :: clip_dict
170 class(base_lr_decay_type), optional, intent(in) :: lr_decay
171 type(adam_optimiser_type) :: optimiser
172 end function optimiser_setup_adam
173 end interface adam_optimiser_type
174
175 !! reduce learning rate on plateau parameters
176 !integer :: wait = 0
177 !integer :: patience = 0
178 !real(real12) :: factor = 0._real12
179 !real(real12) :: min_learning_rate = 0._real12
180
181
182 private
183
184 public :: base_optimiser_type
185 public :: sgd_optimiser_type
186 public :: rmsprop_optimiser_type
187 public :: adagrad_optimiser_type
188 public :: adam_optimiser_type
189
190
191 contains
192
193 !!!#############################################################################
194 !!! set up optimiser
195 !!!#############################################################################
196 7 module function optimiser_setup_base( &
197 learning_rate, &
198 num_params, &
199 regulariser, clip_dict, lr_decay) result(optimiser)
200 implicit none
201 real(real12), optional, intent(in) :: learning_rate
202 integer, optional, intent(in) :: num_params
203 class(base_regulariser_type), optional, intent(in) :: regulariser
204 type(clip_type), optional, intent(in) :: clip_dict
205 class(base_lr_decay_type), optional, intent(in) :: lr_decay
206
207 type(base_optimiser_type) :: optimiser
208
209 integer :: num_params_
210
211
212 !! apply regularisation
213 1/2 1 if(present(regulariser))then
214 1 optimiser%regularisation = .true.
215 1/6 1 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
216 2/4 1 allocate(optimiser%regulariser, source = regulariser)
217 end if
218
219 !! apply clipping
220 1/2 7 if(present(clip_dict)) optimiser%clip_dict = clip_dict
221
222 !! initialise general optimiser parameters
223 2/2 7 if(present(learning_rate)) optimiser%learning_rate = learning_rate
224
225 !! initialise learning rate decay
226 3/4 7 if(present(lr_decay)) then
227 1/6 1 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
228 2/4 1 allocate(optimiser%lr_decay, source = lr_decay)
229 else
230 2/4 6 allocate(optimiser%lr_decay, source = base_lr_decay_type())
231 end if
232
233 !! initialise gradients
234 2/2 7 if(present(num_params)) then
235 1 num_params_ = num_params
236 else
237 6 num_params_ = 1
238 end if
239 7 call optimiser%init_gradients(num_params_)
240
241 2/2 14 end function optimiser_setup_base
242 !!!#############################################################################
243
244
245 !!!#############################################################################
246 !!! initialise optimiser
247 !!!#############################################################################
248 7 subroutine init_base(this, num_params, regulariser, clip_dict)
249 implicit none
250 class(base_optimiser_type), intent(inout) :: this
251 integer, intent(in) :: num_params
252 class(base_regulariser_type), optional, intent(in) :: regulariser
253 type(clip_type), optional, intent(in) :: clip_dict
254
255
256 !! apply regularisation
257 1/4 7 if(present(regulariser))then
258 this%regularisation = .true.
259 if(allocated(this%regulariser)) deallocate(this%regulariser)
260 allocate(this%regulariser, source = regulariser)
261 end if
262
263 !! apply clipping
264 1/2 7 if(present(clip_dict)) then
265 this%clip_dict = clip_dict
266 end if
267
268 !! initialise gradients
269 7 call this%init_gradients(num_params)
270
271 7 end subroutine init_base
272 !!!#############################################################################
273
274
275 !!!#############################################################################
276 !!! initialise gradients
277 !!!#############################################################################
278 12 pure subroutine init_gradients_base(this, num_params)
279 implicit none
280 class(base_optimiser_type), intent(inout) :: this
281 integer, intent(in) :: num_params
282
283 !allocate(this%velocity(num_params), source=0._real12)
284 12 return
285 end subroutine init_gradients_base
286 !!!#############################################################################
287
288
289 !!!#############################################################################
290 !!! minimise the loss function by applying gradients to the parameters
291 !!!#############################################################################
292 2/4 500 pure subroutine minimise_base(this, param, gradient)
293 implicit none
294 class(base_optimiser_type), intent(inout) :: this
295 real(real12), dimension(:), intent(inout) :: param
296 real(real12), dimension(:), intent(inout) :: gradient
297
298 real(real12) :: learning_rate
299
300
301 !! decay learning rate and update iteration
302 500 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
303 500 this%iter = this%iter + 1
304
305 !! update parameters
306 13/24 16478 param = param - learning_rate * gradient
307
308 500 end subroutine minimise_base
309 !!!#############################################################################
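For reference, minimise_base above is a plain gradient-descent step with an optionally decayed learning rate. In the usual notation,

    \eta_t = \mathrm{lr\_decay}(\eta_0, t)
    \theta_{t+1} = \theta_t - \eta_t \, g_t

where \eta_0 is the stored learning_rate, t the iteration counter and g_t the incoming gradient.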
310
311
312 !!!##########################################################################!!!
313 !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!!
314 !!!##########################################################################!!!
315
316
317 !!!#############################################################################
318 !!! set up optimiser
319 !!!#############################################################################
320 3 module function optimiser_setup_sgd( &
321 learning_rate, momentum, &
322 nesterov, num_params, &
323 regulariser, clip_dict, lr_decay) result(optimiser)
324 implicit none
325 real(real12), optional, intent(in) :: learning_rate, momentum
326 logical, optional, intent(in) :: nesterov
327 integer, optional, intent(in) :: num_params
328 class(base_regulariser_type), optional, intent(in) :: regulariser
329 type(clip_type), optional, intent(in) :: clip_dict
330 class(base_lr_decay_type), optional, intent(in) :: lr_decay
331
332 type(sgd_optimiser_type) :: optimiser
333
334 integer :: num_params_
335
336
337 !! apply regularisation
338 1/2 1 if(present(regulariser))then
339 1 optimiser%regularisation = .true.
340 1/6 1 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
341 2/4 1 allocate(optimiser%regulariser, source = regulariser)
342 end if
343
344 !! apply clipping
345 1/2 3 if(present(clip_dict)) optimiser%clip_dict = clip_dict
346
347 !! initialise general optimiser parameters
348 1/2 3 if(present(learning_rate)) optimiser%learning_rate = learning_rate
349 2/2 3 if(present(momentum)) optimiser%momentum = momentum
350
351 !! initialise learning rate decay
352 3/4 3 if(present(lr_decay)) then
353 1/6 1 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
354 2/4 1 allocate(optimiser%lr_decay, source = lr_decay)
355 else
356 2/4 2 allocate(optimiser%lr_decay, source = base_lr_decay_type())
357 end if
358
359 !! initialise nesterov boolean
360 2/2 3 if(present(nesterov)) optimiser%nesterov = nesterov
361
362 !! initialise gradients
363 2/2 3 if(present(num_params)) then
364 1 num_params_ = num_params
365 else
366 2 num_params_ = 1
367 end if
368 3 call optimiser%init_gradients(num_params_)
369
370 2/2 6 end function optimiser_setup_sgd
371 !!!#############################################################################
372
373
374 !!!#############################################################################
375 !!! initialise gradients
376 !!!#############################################################################
377 5 pure subroutine init_gradients_sgd(this, num_params)
378 implicit none
379 class(sgd_optimiser_type), intent(inout) :: this
380 integer, intent(in) :: num_params
381
382
383 !! initialise gradients
384 3/4 5 if(allocated(this%velocity)) deallocate(this%velocity)
385 13/24 23679 allocate(this%velocity(num_params), source=0._real12)
386
387 5 end subroutine init_gradients_sgd
388 !!!#############################################################################
389
390
391 !!!#############################################################################
392 !!! minimise the loss function by applying gradients to the parameters
393 !!!#############################################################################
394 2/4 9 pure subroutine minimise_sgd(this, param, gradient)
395 implicit none
396 class(sgd_optimiser_type), intent(inout) :: this
397 real(real12), dimension(:), intent(inout) :: param
398 real(real12), dimension(:), intent(inout) :: gradient
399
400 real(real12) :: learning_rate
401
402
403 !! decay learning rate and update iteration
404 9 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
405 9 this%iter = this%iter + 1
406
407 !! apply regularisation
408 2/2 9 if(this%regularisation) &
409 call this%regulariser%regularise( &
410 6/12 1 param, gradient, learning_rate)
411
412 2/2 9 if(this%momentum.gt.1.E-8_real12)then !! adaptive learning method
413 4 this%velocity = this%momentum * this%velocity - &
414 12/24 11 learning_rate * gradient
415 else !! standard learning method
416 7/14 24 this%velocity = - learning_rate * gradient
417 end if
418
419 !! update parameters
420 2/2 9 if(this%nesterov)then
421 param = param + this%momentum * this%velocity - &
422 18/34 11 learning_rate * gradient
423 else
424 14/26 24 param = param + this%velocity
425 end if
426
427 9 end subroutine minimise_sgd
428 !!!#############################################################################
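minimise_sgd above implements classical momentum with an optional Nesterov correction. Writing the momentum fraction as \mu and the decayed learning rate as \eta_t, the velocity and parameter updates correspond to

    v_{t+1} = \mu \, v_t - \eta_t \, g_t                        (v_{t+1} = -\eta_t \, g_t when \mu is effectively zero)
    \theta_{t+1} = \theta_t + v_{t+1}                            (nesterov = .false.)
    \theta_{t+1} = \theta_t + \mu \, v_{t+1} - \eta_t \, g_t     (nesterov = .true.)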
429
430
431 !!!##########################################################################!!!
432 !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!!
433 !!!##########################################################################!!!
434
435
436 !!!#############################################################################
437 !!! set up optimiser
438 !!!#############################################################################
439 1 module function optimiser_setup_rmsprop( &
440 learning_rate, &
441 beta, epsilon, &
442 num_params, &
443 regulariser, clip_dict, lr_decay) result(optimiser)
444 implicit none
445 real(real12), optional, intent(in) :: learning_rate
446 real(real12), optional, intent(in) :: beta, epsilon
447 integer, optional, intent(in) :: num_params
448 class(base_regulariser_type), optional, intent(in) :: regulariser
449 type(clip_type), optional, intent(in) :: clip_dict
450 class(base_lr_decay_type), optional, intent(in) :: lr_decay
451
452 type(rmsprop_optimiser_type) :: optimiser
453
454 integer :: num_params_
455
456
457 !! apply regularisation
458 1/2 1 if(present(regulariser))then
459 1 optimiser%regularisation = .true.
460 1/6 1 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
461 2/4 1 allocate(optimiser%regulariser, source = regulariser)
462 end if
463
464 !! apply clipping
465 1/2 1 if(present(clip_dict)) optimiser%clip_dict = clip_dict
466
467 !! initialise general optimiser parameters
468 1/2 1 if(present(learning_rate)) optimiser%learning_rate = learning_rate
469
470 !! initialise learning rate decay
471 2/4 1 if(present(lr_decay)) then
472 1/6 1 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
473 2/4 1 allocate(optimiser%lr_decay, source = lr_decay)
474 else
475 allocate(optimiser%lr_decay, source = base_lr_decay_type())
476 end if
477
478 !! initialise adam parameters
479 1/2 1 if(present(beta)) optimiser%beta = beta
480 1/2 1 if(present(epsilon)) optimiser%epsilon = epsilon
481
482 !! initialise gradients
483 1/2 1 if(present(num_params)) then
484 1 num_params_ = num_params
485 else
486 num_params_ = 1
487 end if
488 1 call optimiser%init_gradients(num_params_)
489
490 1/2 2 end function optimiser_setup_rmsprop
491 !!!#############################################################################
492
493
494 !!!#############################################################################
495 !!! initialise gradients
496 !!!#############################################################################
497 1 pure subroutine init_gradients_rmsprop(this, num_params)
498 implicit none
499 class(rmsprop_optimiser_type), intent(inout) :: this
500 integer, intent(in) :: num_params
501
502
503 !! initialise gradients
504 1/4 1 if(allocated(this%moving_avg)) deallocate(this%moving_avg)
505 13/24 11 allocate(this%moving_avg(num_params), source=0._real12) !1.E-8_real12)
506
507 1 end subroutine init_gradients_rmsprop
508 !!!#############################################################################
509
510
511 !!!#############################################################################
512 !!! minimise the loss function by applying gradients to the parameters
513 !!!#############################################################################
514 2/4 1 pure subroutine minimise_rmsprop(this, param, gradient)
515 implicit none
516 class(rmsprop_optimiser_type), intent(inout) :: this
517 real(real12), dimension(:), intent(inout) :: param
518 real(real12), dimension(:), intent(inout) :: gradient
519
520 real(real12) :: learning_rate
521
522
523 !! decay learning rate and update iteration
524 1 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
525 1 this%iter = this%iter + 1
526
527 !! apply regularisation
528 1/2 1 if(this%regularisation) &
529 call this%regulariser%regularise( &
530 6/12 1 param, gradient, learning_rate)
531
532 4 this%moving_avg = this%beta * this%moving_avg + &
533 12/24 11 (1._real12 - this%beta) * gradient ** 2._real12
534
535 !! update parameters
536 param = param - learning_rate * gradient / &
537 18/34 11 (sqrt(this%moving_avg + this%epsilon))
538
539 1 end subroutine minimise_rmsprop
540 !!!#############################################################################
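minimise_rmsprop above keeps an exponential moving average of squared gradients and scales each step by its square root. With decay \beta and stability constant \epsilon,

    s_{t+1} = \beta \, s_t + (1 - \beta) \, g_t^2
    \theta_{t+1} = \theta_t - \eta_t \, g_t / \sqrt{ s_{t+1} + \epsilon }

note that \epsilon is added inside the square root here, a common implementation variant.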
541
542
543 !!!##########################################################################!!!
544 !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!!
545 !!!##########################################################################!!!
546
547
548 !!!#############################################################################
549 !!! set up optimiser
550 !!!#############################################################################
551 1 module function optimiser_setup_adagrad( &
552 learning_rate, &
553 epsilon, &
554 num_params, &
555 regulariser, clip_dict, lr_decay) result(optimiser)
556 implicit none
557 real(real12), optional, intent(in) :: learning_rate
558 real(real12), optional, intent(in) :: epsilon
559 integer, optional, intent(in) :: num_params
560 class(base_regulariser_type), optional, intent(in) :: regulariser
561 type(clip_type), optional, intent(in) :: clip_dict
562 class(base_lr_decay_type), optional, intent(in) :: lr_decay
563
564 type(adagrad_optimiser_type) :: optimiser
565
566 integer :: num_params_
567
568
569 !! apply regularisation
570 1/2 1 if(present(regulariser))then
571 1 optimiser%regularisation = .true.
572 1/6 1 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
573 2/4 1 allocate(optimiser%regulariser, source = regulariser)
574 end if
575
576 !! apply clipping
577 1/2 1 if(present(clip_dict)) optimiser%clip_dict = clip_dict
578
579 !! initialise general optimiser parameters
580 1/2 1 if(present(learning_rate)) optimiser%learning_rate = learning_rate
581
582 !! initialise learning rate decay
583 2/4 1 if(present(lr_decay)) then
584 1/6 1 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
585 2/4 1 allocate(optimiser%lr_decay, source = lr_decay)
586 else
587 allocate(optimiser%lr_decay, source = base_lr_decay_type())
588 end if
589
590 !! initialise adam parameters
591 1/2 1 if(present(epsilon)) optimiser%epsilon = epsilon
592
593 !! initialise gradients
594 1/2 1 if(present(num_params)) then
595 1 num_params_ = num_params
596 else
597 num_params_ = 1
598 end if
599 1 call optimiser%init_gradients(num_params_)
600
601 1/2 2 end function optimiser_setup_adagrad
602 !!!#############################################################################
603
604
605 !!!#############################################################################
606 !!! initialise gradients
607 !!!#############################################################################
608 1 pure subroutine init_gradients_adagrad(this, num_params)
609 implicit none
610 class(adagrad_optimiser_type), intent(inout) :: this
611 integer, intent(in) :: num_params
612
613
614 !! initialise gradients
615 1/4 1 if(allocated(this%sum_squares)) deallocate(this%sum_squares)
616 13/24 11 allocate(this%sum_squares(num_params), source=0._real12) !1.E-8_real12)
617
618 1 end subroutine init_gradients_adagrad
619 !!!#############################################################################
620
621
622 !!!#############################################################################
623 !!! minimise the loss function by applying gradients to the parameters
624 !!!#############################################################################
625 2/4 1 pure subroutine minimise_adagrad(this, param, gradient)
626 implicit none
627 class(adagrad_optimiser_type), intent(inout) :: this
628 real(real12), dimension(:), intent(inout) :: param
629 real(real12), dimension(:), intent(inout) :: gradient
630
631 real(real12) :: learning_rate
632
633
634 !! decay learning rate and update iteration
635 1 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
636 1 this%iter = this%iter + 1
637
638 !! apply regularisation
639 1/2 1 if(this%regularisation) &
640 call this%regulariser%regularise( &
641 3/6 1 param, this%sum_squares, learning_rate)
642
643 12/24 11 this%sum_squares = this%sum_squares + gradient ** 2._real12
644
645 !! update parameters
646 param = param - learning_rate * gradient / &
647 18/34 11 (sqrt(this%sum_squares + this%epsilon))
648
649 1 end subroutine minimise_adagrad
650 !!!#############################################################################
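minimise_adagrad above accumulates the running sum of squared gradients and uses it to shrink the effective step for frequently updated parameters:

    G_{t+1} = G_t + g_t^2
    \theta_{t+1} = \theta_t - \eta_t \, g_t / \sqrt{ G_{t+1} + \epsilon }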
651
652
653 !!!##########################################################################!!!
654 !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!!
655 !!!##########################################################################!!!
656
657
658 !!!#############################################################################
659 !!! set up optimiser
660 !!!#############################################################################
661 3 module function optimiser_setup_adam( &
662 learning_rate, &
663 beta1, beta2, epsilon, &
664 num_params, &
665 regulariser, clip_dict, lr_decay) result(optimiser)
666 implicit none
667 real(real12), optional, intent(in) :: learning_rate
668 real(real12), optional, intent(in) :: beta1, beta2, epsilon
669 integer, optional, intent(in) :: num_params
670 class(base_regulariser_type), optional, intent(in) :: regulariser
671 type(clip_type), optional, intent(in) :: clip_dict
672 class(base_lr_decay_type), optional, intent(in) :: lr_decay
673
674 type(adam_optimiser_type) :: optimiser
675
676 integer :: num_params_
677
678
679 !! apply regularisation
680 1/2 3 if(present(regulariser))then
681 3 optimiser%regularisation = .true.
682 1/6 3 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
683 2/4 3 allocate(optimiser%regulariser, source = regulariser)
684 end if
685
686 !! apply clipping
687 1/2 3 if(present(clip_dict)) optimiser%clip_dict = clip_dict
688
689 !! initialise general optimiser parameters
690 1/2 3 if(present(learning_rate)) optimiser%learning_rate = learning_rate
691
692 !! initialise learning rate decay
693 2/4 3 if(present(lr_decay)) then
694 1/6 3 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
695 2/4 3 allocate(optimiser%lr_decay, source = lr_decay)
696 else
697 allocate(optimiser%lr_decay, source = base_lr_decay_type())
698 end if
699
700 !! initialise adam parameters
701 1/2 3 if(present(beta1)) optimiser%beta1 = beta1
702 1/2 3 if(present(beta2)) optimiser%beta2 = beta2
703 1/2 3 if(present(epsilon)) optimiser%epsilon = epsilon
704
705 !! initialise gradients
706 1/2 3 if(present(num_params)) then
707 3 num_params_ = num_params
708 else
709 num_params_ = 1
710 end if
711 3 call optimiser%init_gradients(num_params_)
712
713 1/2 6 end function optimiser_setup_adam
714 !!!#############################################################################
715
716
717 !!!#############################################################################
718 !!! initialise gradients
719 !!!#############################################################################
720 3 pure subroutine init_gradients_adam(this, num_params)
721 implicit none
722 class(adam_optimiser_type), intent(inout) :: this
723 integer, intent(in) :: num_params
724
725
726 !! initialise gradients
727 1/4 3 if(allocated(this%m)) deallocate(this%m)
728 1/4 3 if(allocated(this%v)) deallocate(this%v)
729 13/24 33 allocate(this%m(num_params), source=0._real12)
730 13/24 33 allocate(this%v(num_params), source=0._real12)
731
732 3 end subroutine init_gradients_adam
733 !!!#############################################################################
734
735
736 !!!#############################################################################
737 !!! minimise the loss function by applying gradients to the parameters
738 !!!#############################################################################
739 2/4 3 pure subroutine minimise_adam(this, param, gradient)
740 implicit none
741 class(adam_optimiser_type), intent(inout) :: this
742 real(real12), dimension(:), intent(inout) :: param
743 real(real12), dimension(:), intent(inout) :: gradient
744
745 real(real12) :: learning_rate
746
747
748 !! decay learning rate and update iteration
749 3 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
750 3 this%iter = this%iter + 1
751
752 !! apply regularisation
753 1/2 3 if(this%regularisation) &
754 call this%regulariser%regularise( &
755 6/12 3 param, gradient, learning_rate)
756
757 !! adaptive learning method
758 12 this%m = this%beta1 * this%m + &
759 12/24 33 (1._real12 - this%beta1) * gradient
760 12 this%v = this%beta2 * this%v + &
761 12/24 33 (1._real12 - this%beta2) * gradient ** 2._real12
762
763 !! update parameters
764 associate( &
765 12 m_hat => this%m / (1._real12 - this%beta1**this%iter), &
766 12 v_hat => this%v / (1._real12 - this%beta2**this%iter) )
767 14/24 63 select type(regulariser => this%regulariser)
768 type is (l2_regulariser_type)
769 2 select case(regulariser%decoupled)
770 case(.true.)
771 param = param - &
772 learning_rate * &
773 ( m_hat / (sqrt(v_hat) + this%epsilon) + &
774 23/44 11 regulariser%l2 * param )
775 case(.false.)
776 param = param + &
777 learning_rate * &
778 ( ( m_hat + regulariser%l2 * param ) / &
779 25/46 12 (sqrt(v_hat) + this%epsilon) )
780 end select
781 class default
782 param = param + &
783 learning_rate * ( ( m_hat + param ) / &
784 23/44 11 (sqrt(v_hat) + this%epsilon) )
785 end select
786 end associate
787
788 3 end subroutine minimise_adam
789 !!!#############################################################################
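minimise_adam above follows the standard Adam scheme with bias-corrected first and second moments; the placement of the L2 term then depends on the regulariser's decoupled flag (an AdamW-style decoupled weight-decay term versus L2 folded into the scaled step). In the usual notation, with \lambda the regulariser's l2 coefficient,

    m_{t+1} = \beta_1 m_t + (1 - \beta_1) g_t
    v_{t+1} = \beta_2 v_t + (1 - \beta_2) g_t^2
    \hat{m} = m_{t+1} / (1 - \beta_1^{t+1}), \qquad \hat{v} = v_{t+1} / (1 - \beta_2^{t+1})
    \theta_{t+1} = \theta_t - \eta_t \left( \hat{m} / (\sqrt{\hat{v}} + \epsilon) + \lambda \theta_t \right)    (decoupled case)

the coupled and default branches combine \hat{m}, the L2 or parameter term and the \sqrt{\hat{v}} + \epsilon denominator analogously.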
790
791 21/63 93 end module optimiser
792 !!!#############################################################################
793
794