Line | Branch | Exec | Source |
---|---|---|---|
1 | !!!############################################################################# | ||
2 | !!! Code written by Ned Thaddeus Taylor | ||
3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
4 | !!!############################################################################# | ||
5 | !!! module contains implementations of optimisation methods | ||
6 | !!! module contains the following derived types: | ||
7 | !!! - base_optimiser_type - base optimiser type | ||
8 | !!! - sgd_optimiser_type - stochastic gradient descent optimiser type | ||
9 | !!! - rmsprop_optimiser_type - rmsprop optimiser type | ||
10 | !!! - adagrad_optimiser_type - adagrad optimiser type | ||
11 | !!! - adam_optimiser_type - adam optimiser type | ||
12 | !!!################## | ||
13 | !!! <NAME>_optimiser_type contains the following procedures: | ||
14 | !!! - init_gradients - initialise gradients | ||
15 | !!! - minimise - minimise the loss function by applying gradients to ... | ||
16 | !!! ... the parameters | ||
17 | !!!############################################################################# | ||
18 | !!! Attribution statement: | ||
19 | !!! The following module is based on code from the neural-fortran library | ||
20 | !!! https://github.com/modern-fortran/neural-fortran/blob/main/src/nf/nf_optimizers.f90 | ||
21 | !!! The implementation of base_optimiser_type is based on the ... | ||
22 | !!! ... optimizer_base_type from the neural-fortran library | ||
23 | !!! The same applies to the implementation of the sgd_optimiser_type, ... | ||
24 | !!! ... rmsprop_optimiser_type, adagrad_optimiser_type, and adam_optimiser_type | ||
25 | !!!############################################################################# | ||
26 | module optimiser | ||
27 | use constants, only: real12 | ||
28 | use clipper, only: clip_type | ||
29 | use regulariser, only: & | ||
30 | base_regulariser_type, & | ||
31 | l2_regulariser_type | ||
32 | use learning_rate_decay, only: base_lr_decay_type | ||
33 | implicit none | ||
34 | |||
35 | !!!----------------------------------------------------------------------------- | ||
36 | !!! learning parameter type | ||
37 | !!!----------------------------------------------------------------------------- | ||
38 | type :: base_optimiser_type !!base_optimiser_type | ||
39 | !! iter = iteration number | ||
40 | !! learning_rate = learning rate hyperparameter | ||
41 | !! regularisation = apply regularisation | ||
42 | !! regulariser = regularisation method | ||
43 | !! clip_dict = clipping dictionary | ||
44 | integer :: iter = 0 | ||
45 | real(real12) :: learning_rate = 0.01_real12 | ||
46 | logical :: regularisation = .false. | ||
47 | class(base_regulariser_type), allocatable :: regulariser | ||
48 | class(base_lr_decay_type), allocatable :: lr_decay | ||
49 | type(clip_type) :: clip_dict | ||
50 | contains | ||
51 | procedure, pass(this) :: init => init_base | ||
52 | procedure, pass(this) :: init_gradients => init_gradients_base | ||
53 | procedure, pass(this) :: minimise => minimise_base | ||
54 | end type base_optimiser_type | ||
55 | |||
56 | interface base_optimiser_type | ||
57 | module function optimiser_setup_base( & | ||
58 | learning_rate, & | ||
59 | num_params, & | ||
60 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
61 | real(real12), optional, intent(in) :: learning_rate | ||
62 | integer, optional, intent(in) :: num_params | ||
63 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
64 | type(clip_type), optional, intent(in) :: clip_dict | ||
65 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
66 | type(base_optimiser_type) :: optimiser | ||
67 | end function optimiser_setup_base | ||
68 | end interface base_optimiser_type | ||
69 | |||
70 | |||
71 | !!!----------------------------------------------------------------------------- | ||
72 | |||
73 | type, extends(base_optimiser_type) :: sgd_optimiser_type | ||
74 | logical :: nesterov = .false. | ||
75 | real(real12) :: momentum = 0._real12 ! fraction of momentum based learning | ||
76 | real(real12), allocatable, dimension(:) :: velocity | ||
77 | contains | ||
78 | procedure, pass(this) :: init_gradients => init_gradients_sgd | ||
79 | procedure, pass(this) :: minimise => minimise_sgd | ||
80 | end type sgd_optimiser_type | ||
81 | |||
82 | interface sgd_optimiser_type | ||
83 | module function optimiser_setup_sgd( & | ||
84 | learning_rate, momentum, & | ||
85 | nesterov, num_params, & | ||
86 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
87 | real(real12), optional, intent(in) :: learning_rate, momentum | ||
88 | logical, optional, intent(in) :: nesterov | ||
89 | integer, optional, intent(in) :: num_params | ||
90 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
91 | type(clip_type), optional, intent(in) :: clip_dict | ||
92 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
93 | type(sgd_optimiser_type) :: optimiser | ||
94 | end function optimiser_setup_sgd | ||
95 | end interface sgd_optimiser_type | ||
96 | |||
97 | !!!----------------------------------------------------------------------------- | ||
98 | |||
99 | type, extends(base_optimiser_type) :: rmsprop_optimiser_type | ||
100 | real(real12) :: beta = 0._real12 | ||
101 | real(real12) :: epsilon = 1.E-8_real12 | ||
102 | real(real12), allocatable, dimension(:) :: moving_avg | ||
103 | contains | ||
104 | procedure, pass(this) :: init_gradients => init_gradients_rmsprop | ||
105 | procedure, pass(this) :: minimise => minimise_rmsprop | ||
106 | end type rmsprop_optimiser_type | ||
107 | |||
108 | interface rmsprop_optimiser_type | ||
109 | module function optimiser_setup_rmsprop( & | ||
110 | learning_rate, beta, & | ||
111 | epsilon, num_params, & | ||
112 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
113 | real(real12), optional, intent(in) :: learning_rate, beta, epsilon | ||
114 | integer, optional, intent(in) :: num_params | ||
115 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
116 | type(clip_type), optional, intent(in) :: clip_dict | ||
117 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
118 | type(rmsprop_optimiser_type) :: optimiser | ||
119 | end function optimiser_setup_rmsprop | ||
120 | end interface rmsprop_optimiser_type | ||
121 | |||
122 | !!!----------------------------------------------------------------------------- | ||
123 | |||
124 | type, extends(base_optimiser_type) :: adagrad_optimiser_type | ||
125 | real(real12) :: epsilon = 1.E-8_real12 | ||
126 | real(real12), allocatable, dimension(:) :: sum_squares | ||
127 | contains | ||
128 | procedure, pass(this) :: init_gradients => init_gradients_adagrad | ||
129 | procedure, pass(this) :: minimise => minimise_adagrad | ||
130 | end type adagrad_optimiser_type | ||
131 | |||
132 | interface adagrad_optimiser_type | ||
133 | module function optimiser_setup_adagrad( & | ||
134 | learning_rate, & | ||
135 | epsilon, num_params, & | ||
136 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
137 | real(real12), optional, intent(in) :: learning_rate, epsilon | ||
138 | integer, optional, intent(in) :: num_params | ||
139 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
140 | type(clip_type), optional, intent(in) :: clip_dict | ||
141 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
142 | type(adagrad_optimiser_type) :: optimiser | ||
143 | end function optimiser_setup_adagrad | ||
144 | end interface adagrad_optimiser_type | ||
145 | |||
146 | !!!----------------------------------------------------------------------------- | ||
147 | |||
148 | type, extends(base_optimiser_type) :: adam_optimiser_type | ||
149 | real(real12) :: beta1 = 0.9_real12 | ||
150 | real(real12) :: beta2 = 0.999_real12 | ||
151 | real(real12) :: epsilon = 1.E-8_real12 | ||
152 | real(real12), allocatable, dimension(:) :: m | ||
153 | real(real12), allocatable, dimension(:) :: v | ||
154 | contains | ||
155 | procedure, pass(this) :: init_gradients => init_gradients_adam | ||
156 | procedure, pass(this) :: minimise => minimise_adam | ||
157 | end type adam_optimiser_type | ||
158 | |||
159 | interface adam_optimiser_type | ||
160 | module function optimiser_setup_adam( & | ||
161 | learning_rate, & | ||
162 | beta1, beta2, epsilon, & | ||
163 | num_params, & | ||
164 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
165 | real(real12), optional, intent(in) :: learning_rate | ||
166 | real(real12), optional, intent(in) :: beta1, beta2, epsilon | ||
167 | integer, optional, intent(in) :: num_params | ||
168 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
169 | type(clip_type), optional, intent(in) :: clip_dict | ||
170 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
171 | type(adam_optimiser_type) :: optimiser | ||
172 | end function optimiser_setup_adam | ||
173 | end interface adam_optimiser_type | ||
174 | |||
175 | !! reduce learning rate on plateau parameters | ||
176 | !integer :: wait = 0 | ||
177 | !integer :: patience = 0 | ||
178 | !real(real12) :: factor = 0._real12 | ||
179 | !real(real12) :: min_learning_rate = 0._real12 | ||
180 | |||
181 | |||
182 | private | ||
183 | |||
184 | public :: base_optimiser_type | ||
185 | public :: sgd_optimiser_type | ||
186 | public :: rmsprop_optimiser_type | ||
187 | public :: adagrad_optimiser_type | ||
188 | public :: adam_optimiser_type | ||
189 | |||
190 | |||
191 | contains | ||
192 | |||
193 | !!!############################################################################# | ||
194 | !!! set up optimiser | ||
195 | !!!############################################################################# | ||
196 | 7 | module function optimiser_setup_base( & | |
197 | learning_rate, & | ||
198 | num_params, & | ||
199 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
200 | implicit none | ||
201 | real(real12), optional, intent(in) :: learning_rate | ||
202 | integer, optional, intent(in) :: num_params | ||
203 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
204 | type(clip_type), optional, intent(in) :: clip_dict | ||
205 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
206 | |||
207 | type(base_optimiser_type) :: optimiser | ||
208 | |||
209 | integer :: num_params_ | ||
210 | |||
211 | |||
212 | !! apply regularisation | ||
213 | 1/2 | 1 | if(present(regulariser))then |
214 | 1 | optimiser%regularisation = .true. | |
215 | 1/6 | 1 | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) |
216 | 2/4 | 1 | allocate(optimiser%regulariser, source = regulariser) |
217 | end if | ||
218 | |||
219 | !! apply clipping | ||
220 | 1/2 | 7 | if(present(clip_dict)) optimiser%clip_dict = clip_dict |
221 | |||
222 | !! initialise general optimiser parameters | ||
223 | 2/2 | 7 | if(present(learning_rate)) optimiser%learning_rate = learning_rate |
224 | |||
225 | !! initialise learning rate decay | ||
226 | 3/4 | 7 | if(present(lr_decay)) then |
227 | 1/6 | 1 | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) |
228 | 2/4 | 1 | allocate(optimiser%lr_decay, source = lr_decay) |
229 | else | ||
230 | 2/4 | 6 | allocate(optimiser%lr_decay, source = base_lr_decay_type()) |
231 | end if | ||
232 | |||
233 | !! initialise gradients | ||
234 | 2/2 | 7 | if(present(num_params)) then |
235 | 1 | num_params_ = num_params | |
236 | else | ||
237 | 6 | num_params_ = 1 | |
238 | end if | ||
239 | 7 | call optimiser%init_gradients(num_params_) | |
240 | |||
241 | 2/2 | 14 | end function optimiser_setup_base |
242 | !!!############################################################################# | ||
243 | |||
244 | |||
245 | !!!############################################################################# | ||
246 | !!! initialise optimiser | ||
247 | !!!############################################################################# | ||
248 | 7 | subroutine init_base(this, num_params, regulariser, clip_dict) | |
249 | implicit none | ||
250 | class(base_optimiser_type), intent(inout) :: this | ||
251 | integer, intent(in) :: num_params | ||
252 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
253 | type(clip_type), optional, intent(in) :: clip_dict | ||
254 | |||
255 | |||
256 | !! apply regularisation | ||
257 | 1/4 | 7 | if(present(regulariser))then |
258 | ✗ | this%regularisation = .true. | |
259 | ✗ | if(allocated(this%regulariser)) deallocate(this%regulariser) | |
260 | ✗ | allocate(this%regulariser, source = regulariser) | |
261 | end if | ||
262 | |||
263 | !! apply clipping | ||
264 | 1/2 | 7 | if(present(clip_dict)) then |
265 | ✗ | this%clip_dict = clip_dict | |
266 | end if | ||
267 | |||
268 | !! initialise gradients | ||
269 | 7 | call this%init_gradients(num_params) | |
270 | |||
271 | 7 | end subroutine init_base | |
272 | !!!############################################################################# | ||
273 | |||
274 | |||
275 | !!!############################################################################# | ||
276 | !!! initialise gradients | ||
277 | !!!############################################################################# | ||
278 | 12 | pure subroutine init_gradients_base(this, num_params) | |
279 | implicit none | ||
280 | class(base_optimiser_type), intent(inout) :: this | ||
281 | integer, intent(in) :: num_params | ||
282 | |||
283 | !allocate(this%velocity(num_params), source=0._real12) | ||
284 | 12 | return | |
285 | end subroutine init_gradients_base | ||
286 | !!!############################################################################# | ||
287 | |||
288 | |||
289 | !!!############################################################################# | ||
290 | !!! minimise the loss function by applying gradients to the parameters | ||
291 | !!!############################################################################# | ||
292 | 2/4 | 500 | pure subroutine minimise_base(this, param, gradient) |
293 | implicit none | ||
294 | class(base_optimiser_type), intent(inout) :: this | ||
295 | real(real12), dimension(:), intent(inout) :: param | ||
296 | real(real12), dimension(:), intent(inout) :: gradient | ||
297 | |||
298 | real(real12) :: learning_rate | ||
299 | |||
300 | |||
301 | !! decay learning rate and update iteration | ||
302 | 500 | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
303 | 500 | this%iter = this%iter + 1 | |
304 | |||
305 | !! update parameters | ||
306 | 13/24 | 16478 | param = param - learning_rate * gradient |
307 | |||
308 | 500 | end subroutine minimise_base | |
309 | !!!############################################################################# | ||
310 | |||
311 | |||
312 | !!!##########################################################################!!! | ||
313 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
314 | !!!##########################################################################!!! | ||
315 | |||
316 | |||
317 | !!!############################################################################# | ||
318 | !!! set up optimiser | ||
319 | !!!############################################################################# | ||
320 | 3 | module function optimiser_setup_sgd( & | |
321 | learning_rate, momentum, & | ||
322 | nesterov, num_params, & | ||
323 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
324 | implicit none | ||
325 | real(real12), optional, intent(in) :: learning_rate, momentum | ||
326 | logical, optional, intent(in) :: nesterov | ||
327 | integer, optional, intent(in) :: num_params | ||
328 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
329 | type(clip_type), optional, intent(in) :: clip_dict | ||
330 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
331 | |||
332 | type(sgd_optimiser_type) :: optimiser | ||
333 | |||
334 | integer :: num_params_ | ||
335 | |||
336 | |||
337 | !! apply regularisation | ||
338 | 1/2 | 1 | if(present(regulariser))then |
339 | 1 | optimiser%regularisation = .true. | |
340 | 1/6 | 1 | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) |
341 | 2/4 | 1 | allocate(optimiser%regulariser, source = regulariser) |
342 | end if | ||
343 | |||
344 | !! apply clipping | ||
345 | 1/2 | 3 | if(present(clip_dict)) optimiser%clip_dict = clip_dict |
346 | |||
347 | !! initialise general optimiser parameters | ||
348 | 1/2 | 3 | if(present(learning_rate)) optimiser%learning_rate = learning_rate |
349 | 2/2 | 3 | if(present(momentum)) optimiser%momentum = momentum |
350 | |||
351 | !! initialise learning rate decay | ||
352 | 3/4 | 3 | if(present(lr_decay)) then |
353 | 1/6 | 1 | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) |
354 | 2/4 | 1 | allocate(optimiser%lr_decay, source = lr_decay) |
355 | else | ||
356 | 2/4 | 2 | allocate(optimiser%lr_decay, source = base_lr_decay_type()) |
357 | end if | ||
358 | |||
359 | !! initialise nesterov boolean | ||
360 | 2/2 | 3 | if(present(nesterov)) optimiser%nesterov = nesterov |
361 | |||
362 | !! initialise gradients | ||
363 | 2/2 | 3 | if(present(num_params)) then |
364 | 1 | num_params_ = num_params | |
365 | else | ||
366 | 2 | num_params_ = 1 | |
367 | end if | ||
368 | 3 | call optimiser%init_gradients(num_params_) | |
369 | |||
370 | 2/2 | 6 | end function optimiser_setup_sgd |
371 | !!!############################################################################# | ||
372 | |||
373 | |||
374 | !!!############################################################################# | ||
375 | !!! initialise gradients | ||
376 | !!!############################################################################# | ||
377 | 5 | pure subroutine init_gradients_sgd(this, num_params) | |
378 | implicit none | ||
379 | class(sgd_optimiser_type), intent(inout) :: this | ||
380 | integer, intent(in) :: num_params | ||
381 | |||
382 | |||
383 | !! initialise gradients | ||
384 | 3/4 | 5 | if(allocated(this%velocity)) deallocate(this%velocity) |
385 | 13/24 | 23679 | allocate(this%velocity(num_params), source=0._real12) |
386 | |||
387 | 5 | end subroutine init_gradients_sgd | |
388 | !!!############################################################################# | ||
389 | |||
390 | |||
391 | !!!############################################################################# | ||
392 | !!! minimise the loss function by applying gradients to the parameters | ||
393 | !!!############################################################################# | ||
394 | 2/4 | 9 | pure subroutine minimise_sgd(this, param, gradient) |
395 | implicit none | ||
396 | class(sgd_optimiser_type), intent(inout) :: this | ||
397 | real(real12), dimension(:), intent(inout) :: param | ||
398 | real(real12), dimension(:), intent(inout) :: gradient | ||
399 | |||
400 | real(real12) :: learning_rate | ||
401 | |||
402 | |||
403 | !! decay learning rate and update iteration | ||
404 | 9 | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
405 | 9 | this%iter = this%iter + 1 | |
406 | |||
407 | !! apply regularisation | ||
408 | 2/2 | 9 | if(this%regularisation) & |
409 | call this%regulariser%regularise( & | ||
410 | 6/12 | 1 | param, gradient, learning_rate) |
411 | |||
412 | 2/2 | 9 | if(this%momentum.gt.1.E-8_real12)then !! adaptive learning method |
413 | 4 | this%velocity = this%momentum * this%velocity - & | |
414 | 12/24 | 11 | learning_rate * gradient |
415 | else !! standard learning method | ||
416 | 7/14 | 24 | this%velocity = - learning_rate * gradient |
417 | end if | ||
418 | |||
419 | !! update parameters | ||
420 | 2/2 | 9 | if(this%nesterov)then |
421 | ✗ | param = param + this%momentum * this%velocity - & | |
422 | 18/34 | 11 | learning_rate * gradient |
423 | else | ||
424 | 14/26 | 24 | param = param + this%velocity |
425 | end if | ||
426 | |||
427 | 9 | end subroutine minimise_sgd | |
428 | !!!############################################################################# | ||
429 | |||
430 | |||
431 | !!!##########################################################################!!! | ||
432 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
433 | !!!##########################################################################!!! | ||
434 | |||
435 | |||
436 | !!!############################################################################# | ||
437 | !!! set up optimiser | ||
438 | !!!############################################################################# | ||
439 | 1 | module function optimiser_setup_rmsprop( & | |
440 | learning_rate, & | ||
441 | beta, epsilon, & | ||
442 | num_params, & | ||
443 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
444 | implicit none | ||
445 | real(real12), optional, intent(in) :: learning_rate | ||
446 | real(real12), optional, intent(in) :: beta, epsilon | ||
447 | integer, optional, intent(in) :: num_params | ||
448 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
449 | type(clip_type), optional, intent(in) :: clip_dict | ||
450 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
451 | |||
452 | type(rmsprop_optimiser_type) :: optimiser | ||
453 | |||
454 | integer :: num_params_ | ||
455 | |||
456 | |||
457 | !! apply regularisation | ||
458 | 1/2 | 1 | if(present(regulariser))then |
459 | 1 | optimiser%regularisation = .true. | |
460 | 1/6 | 1 | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) |
461 | 2/4 | 1 | allocate(optimiser%regulariser, source = regulariser) |
462 | end if | ||
463 | |||
464 | !! apply clipping | ||
465 | 1/2 | 1 | if(present(clip_dict)) optimiser%clip_dict = clip_dict |
466 | |||
467 | !! initialise general optimiser parameters | ||
468 | 1/2 | 1 | if(present(learning_rate)) optimiser%learning_rate = learning_rate |
469 | |||
470 | !! initialise learning rate decay | ||
471 | 2/4 | 1 | if(present(lr_decay)) then |
472 | 1/6 | 1 | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) |
473 | 2/4 | 1 | allocate(optimiser%lr_decay, source = lr_decay) |
474 | else | ||
475 | ✗ | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
476 | end if | ||
477 | |||
478 | !! initialise rmsprop parameters | ||
479 | 1/2 | 1 | if(present(beta)) optimiser%beta = beta |
480 | 1/2 | 1 | if(present(epsilon)) optimiser%epsilon = epsilon |
481 | |||
482 | !! initialise gradients | ||
483 | 1/2 | 1 | if(present(num_params)) then |
484 | 1 | num_params_ = num_params | |
485 | else | ||
486 | ✗ | num_params_ = 1 | |
487 | end if | ||
488 | 1 | call optimiser%init_gradients(num_params_) | |
489 | |||
490 | 1/2 | 2 | end function optimiser_setup_rmsprop |
491 | !!!############################################################################# | ||
492 | |||
493 | |||
494 | !!!############################################################################# | ||
495 | !!! initialise gradients | ||
496 | !!!############################################################################# | ||
497 | 1 | pure subroutine init_gradients_rmsprop(this, num_params) | |
498 | implicit none | ||
499 | class(rmsprop_optimiser_type), intent(inout) :: this | ||
500 | integer, intent(in) :: num_params | ||
501 | |||
502 | |||
503 | !! initialise gradients | ||
504 | 1/4 | 1 | if(allocated(this%moving_avg)) deallocate(this%moving_avg) |
505 | 13/24 | 11 | allocate(this%moving_avg(num_params), source=0._real12) !1.E-8_real12) |
506 | |||
507 | 1 | end subroutine init_gradients_rmsprop | |
508 | !!!############################################################################# | ||
509 | |||
510 | |||
511 | !!!############################################################################# | ||
512 | !!! minimise the loss function by applying gradients to the parameters | ||
513 | !!!############################################################################# | ||
514 | 2/4 | 1 | pure subroutine minimise_rmsprop(this, param, gradient) |
515 | implicit none | ||
516 | class(rmsprop_optimiser_type), intent(inout) :: this | ||
517 | real(real12), dimension(:), intent(inout) :: param | ||
518 | real(real12), dimension(:), intent(inout) :: gradient | ||
519 | |||
520 | real(real12) :: learning_rate | ||
521 | |||
522 | |||
523 | !! decay learning rate and update iteration | ||
524 | 1 | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
525 | 1 | this%iter = this%iter + 1 | |
526 | |||
527 | !! apply regularisation | ||
528 | 1/2 | 1 | if(this%regularisation) & |
529 | call this%regulariser%regularise( & | ||
530 | 6/12 | 1 | param, gradient, learning_rate) |
531 | |||
532 | 4 | this%moving_avg = this%beta * this%moving_avg + & | |
533 | 12/24 | 11 | (1._real12 - this%beta) * gradient ** 2._real12 |
534 | |||
535 | !! update parameters | ||
536 | ✗ | param = param - learning_rate * gradient / & | |
537 | 18/34 | 11 | (sqrt(this%moving_avg + this%epsilon)) |
538 | |||
539 | 1 | end subroutine minimise_rmsprop | |
540 | !!!############################################################################# | ||
541 | |||
542 | |||
543 | !!!##########################################################################!!! | ||
544 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
545 | !!!##########################################################################!!! | ||
546 | |||
547 | |||
548 | !!!############################################################################# | ||
549 | !!! set up optimiser | ||
550 | !!!############################################################################# | ||
551 | 1 | module function optimiser_setup_adagrad( & | |
552 | learning_rate, & | ||
553 | epsilon, & | ||
554 | num_params, & | ||
555 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
556 | implicit none | ||
557 | real(real12), optional, intent(in) :: learning_rate | ||
558 | real(real12), optional, intent(in) :: epsilon | ||
559 | integer, optional, intent(in) :: num_params | ||
560 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
561 | type(clip_type), optional, intent(in) :: clip_dict | ||
562 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
563 | |||
564 | type(adagrad_optimiser_type) :: optimiser | ||
565 | |||
566 | integer :: num_params_ | ||
567 | |||
568 | |||
569 | !! apply regularisation | ||
570 | 1/2 | 1 | if(present(regulariser))then |
571 | 1 | optimiser%regularisation = .true. | |
572 | 1/6 | 1 | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) |
573 | 2/4 | 1 | allocate(optimiser%regulariser, source = regulariser) |
574 | end if | ||
575 | |||
576 | !! apply clipping | ||
577 | 1/2 | 1 | if(present(clip_dict)) optimiser%clip_dict = clip_dict |
578 | |||
579 | !! initialise general optimiser parameters | ||
580 | 1/2 | 1 | if(present(learning_rate)) optimiser%learning_rate = learning_rate |
581 | |||
582 | !! initialise learning rate decay | ||
583 | 2/4 | 1 | if(present(lr_decay)) then |
584 | 1/6 | 1 | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) |
585 | 2/4 | 1 | allocate(optimiser%lr_decay, source = lr_decay) |
586 | else | ||
587 | ✗ | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
588 | end if | ||
589 | |||
590 | !! initialise adagrad parameters | ||
591 | 1/2 | 1 | if(present(epsilon)) optimiser%epsilon = epsilon |
592 | |||
593 | !! initialise gradients | ||
594 | 1/2 | 1 | if(present(num_params)) then |
595 | 1 | num_params_ = num_params | |
596 | else | ||
597 | ✗ | num_params_ = 1 | |
598 | end if | ||
599 | 1 | call optimiser%init_gradients(num_params_) | |
600 | |||
601 | 1/2 | 2 | end function optimiser_setup_adagrad |
602 | !!!############################################################################# | ||
603 | |||
604 | |||
605 | !!!############################################################################# | ||
606 | !!! initialise gradients | ||
607 | !!!############################################################################# | ||
608 | 1 | pure subroutine init_gradients_adagrad(this, num_params) | |
609 | implicit none | ||
610 | class(adagrad_optimiser_type), intent(inout) :: this | ||
611 | integer, intent(in) :: num_params | ||
612 | |||
613 | |||
614 | !! initialise gradients | ||
615 | 1/4 | 1 | if(allocated(this%sum_squares)) deallocate(this%sum_squares) |
616 | 13/24 | 11 | allocate(this%sum_squares(num_params), source=0._real12) !1.E-8_real12) |
617 | |||
618 | 1 | end subroutine init_gradients_adagrad | |
619 | !!!############################################################################# | ||
620 | |||
621 | |||
622 | !!!############################################################################# | ||
623 | !!! minimise the loss function by applying gradients to the parameters | ||
624 | !!!############################################################################# | ||
625 | 2/4 | 1 | pure subroutine minimise_adagrad(this, param, gradient) |
626 | implicit none | ||
627 | class(adagrad_optimiser_type), intent(inout) :: this | ||
628 | real(real12), dimension(:), intent(inout) :: param | ||
629 | real(real12), dimension(:), intent(inout) :: gradient | ||
630 | |||
631 | real(real12) :: learning_rate | ||
632 | |||
633 | |||
634 | !! decay learning rate and update iteration | ||
635 | 1 | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
636 | 1 | this%iter = this%iter + 1 | |
637 | |||
638 | !! apply regularisation | ||
639 | 1/2 | 1 | if(this%regularisation) & |
640 | call this%regulariser%regularise( & | ||
641 | 3/6 | 1 | param, this%sum_squares, learning_rate) |
642 | |||
643 | 12/24 | 11 | this%sum_squares = this%sum_squares + gradient ** 2._real12 |
644 | |||
645 | !! update parameters | ||
646 | ✗ | param = param - learning_rate * gradient / & | |
647 | 18/34 | 11 | (sqrt(this%sum_squares + this%epsilon)) |
648 | |||
649 | 1 | end subroutine minimise_adagrad | |
650 | !!!############################################################################# | ||
651 | |||
652 | |||
653 | !!!##########################################################################!!! | ||
654 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
655 | !!!##########################################################################!!! | ||
656 | |||
657 | |||
658 | !!!############################################################################# | ||
659 | !!! set up optimiser | ||
660 | !!!############################################################################# | ||
661 | 3 | module function optimiser_setup_adam( & | |
662 | learning_rate, & | ||
663 | beta1, beta2, epsilon, & | ||
664 | num_params, & | ||
665 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
666 | implicit none | ||
667 | real(real12), optional, intent(in) :: learning_rate | ||
668 | real(real12), optional, intent(in) :: beta1, beta2, epsilon | ||
669 | integer, optional, intent(in) :: num_params | ||
670 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
671 | type(clip_type), optional, intent(in) :: clip_dict | ||
672 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
673 | |||
674 | type(adam_optimiser_type) :: optimiser | ||
675 | |||
676 | integer :: num_params_ | ||
677 | |||
678 | |||
679 | !! apply regularisation | ||
680 | 1/2 | 3 | if(present(regulariser))then |
681 | 3 | optimiser%regularisation = .true. | |
682 | 1/6 | 3 | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) |
683 | 2/4 | 3 | allocate(optimiser%regulariser, source = regulariser) |
684 | end if | ||
685 | |||
686 | !! apply clipping | ||
687 | 1/2 | 3 | if(present(clip_dict)) optimiser%clip_dict = clip_dict |
688 | |||
689 | !! initialise general optimiser parameters | ||
690 | 1/2 | 3 | if(present(learning_rate)) optimiser%learning_rate = learning_rate |
691 | |||
692 | !! initialise learning rate decay | ||
693 | 2/4 | 3 | if(present(lr_decay)) then |
694 | 1/6 | 3 | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) |
695 | 2/4 | 3 | allocate(optimiser%lr_decay, source = lr_decay) |
696 | else | ||
697 | ✗ | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
698 | end if | ||
699 | |||
700 | !! initialise adam parameters | ||
701 | 1/2 | 3 | if(present(beta1)) optimiser%beta1 = beta1 |
702 | 1/2 | 3 | if(present(beta2)) optimiser%beta2 = beta2 |
703 | 1/2 | 3 | if(present(epsilon)) optimiser%epsilon = epsilon |
704 | |||
705 | !! initialise gradients | ||
706 | 1/2 | 3 | if(present(num_params)) then |
707 | 3 | num_params_ = num_params | |
708 | else | ||
709 | ✗ | num_params_ = 1 | |
710 | end if | ||
711 | 3 | call optimiser%init_gradients(num_params_) | |
712 | |||
713 | 1/2 | 6 | end function optimiser_setup_adam |
714 | !!!############################################################################# | ||
715 | |||
716 | |||
717 | !!!############################################################################# | ||
718 | !!! initialise gradients | ||
719 | !!!############################################################################# | ||
720 | 3 | pure subroutine init_gradients_adam(this, num_params) | |
721 | implicit none | ||
722 | class(adam_optimiser_type), intent(inout) :: this | ||
723 | integer, intent(in) :: num_params | ||
724 | |||
725 | |||
726 | !! initialise gradients | ||
727 | 1/4 | 3 | if(allocated(this%m)) deallocate(this%m) |
728 | 1/4 | 3 | if(allocated(this%v)) deallocate(this%v) |
729 | 13/24 | 33 | allocate(this%m(num_params), source=0._real12) |
730 | 13/24 | 33 | allocate(this%v(num_params), source=0._real12) |
731 | |||
732 | 3 | end subroutine init_gradients_adam | |
733 | !!!############################################################################# | ||
734 | |||
735 | |||
736 | !!!############################################################################# | ||
737 | !!! minimise the loss function by applying gradients to the parameters | ||
738 | !!!############################################################################# | ||
739 | 2/4 | 3 | pure subroutine minimise_adam(this, param, gradient) |
740 | implicit none | ||
741 | class(adam_optimiser_type), intent(inout) :: this | ||
742 | real(real12), dimension(:), intent(inout) :: param | ||
743 | real(real12), dimension(:), intent(inout) :: gradient | ||
744 | |||
745 | real(real12) :: learning_rate | ||
746 | |||
747 | |||
748 | !! decay learning rate and update iteration | ||
749 | 3 | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
750 | 3 | this%iter = this%iter + 1 | |
751 | |||
752 | !! apply regularisation | ||
753 | 1/2 | 3 | if(this%regularisation) & |
754 | call this%regulariser%regularise( & | ||
755 | 6/12 | 3 | param, gradient, learning_rate) |
756 | |||
757 | !! adaptive learning method | ||
758 | 12 | this%m = this%beta1 * this%m + & | |
759 | 12/24 | 33 | (1._real12 - this%beta1) * gradient |
760 | 12 | this%v = this%beta2 * this%v + & | |
761 | 12/24 | 33 | (1._real12 - this%beta2) * gradient ** 2._real12 |
762 | |||
763 | !! update parameters | ||
764 | associate( & | ||
765 | 12 | m_hat => this%m / (1._real12 - this%beta1**this%iter), & | |
766 | 12 | v_hat => this%v / (1._real12 - this%beta2**this%iter) ) | |
767 | 14/24 | 63 | select type(regulariser => this%regulariser) |
768 | type is (l2_regulariser_type) | ||
769 | 2 | select case(regulariser%decoupled) | |
770 | case(.true.) | ||
771 | ✗ | param = param - & | |
772 | learning_rate * & | ||
773 | ✗ | ( m_hat / (sqrt(v_hat) + this%epsilon) + & | |
774 | 23/44 | 11 | regulariser%l2 * param ) |
775 | case(.false.) | ||
776 | ✗ | param = param + & | |
777 | learning_rate * & | ||
778 | ✗ | ( ( m_hat + regulariser%l2 * param ) / & | |
779 | 25/46 | 12 | (sqrt(v_hat) + this%epsilon) ) |
780 | end select | ||
781 | class default | ||
782 | ✗ | param = param + & | |
783 | ✗ | learning_rate * ( ( m_hat + param ) / & | |
784 | 23/44 | 11 | (sqrt(v_hat) + this%epsilon) ) |
785 | end select | ||
786 | end associate | ||
787 | |||
788 | 3 | end subroutine minimise_adam | |
789 | !!!############################################################################# | ||
790 | |||
791 | 21/63 | 93 | end module optimiser |
792 | !!!############################################################################# | ||
793 | |||
794 |
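
As a reading aid (not part of the coverage data): the moment updates on source lines 758-761, the bias corrections on lines 765-766, and the decoupled weight-decay branch on lines 771-774 of `minimise_adam` correspond to the standard Adam/AdamW-style update, with $g_t$ the `gradient`, $t$ the iteration counter `iter`, $\eta$ the decayed `learning_rate`, $\epsilon$ the `epsilon` component, and $\lambda$ the `l2` coefficient of `l2_regulariser_type` when `decoupled` is true:

$$ m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 $$

$$ \hat m_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad \theta_t = \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} + \lambda\,\theta_{t-1}\right) $$

The coupled (`decoupled = .false.`) and `class default` branches instead fold the regularisation term into the numerator alongside $\hat m_t$, as shown in the listing.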
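The listing above covers only the optimiser module itself. The fragment below is a minimal usage sketch, not code from ATHENA or its test suite: it assumes the ATHENA `constants` and `optimiser` modules are compiled and on the build path, and that the default `base_lr_decay_type` leaves the learning rate unchanged; the program name, toy loss, and hyperparameter values are illustrative only.

```fortran
program sgd_sketch
   !! Hypothetical driver: momentum SGD on the toy loss L = 0.5 * sum(param**2),
   !! whose gradient with respect to param is simply param.
   use constants, only: real12
   use optimiser, only: sgd_optimiser_type
   implicit none

   type(sgd_optimiser_type) :: opt
   real(real12) :: param(3), gradient(3)
   integer :: step

   !! all constructor arguments are optional; num_params sizes the velocity buffer
   opt = sgd_optimiser_type(learning_rate=0.1_real12, momentum=0.9_real12, &
        num_params=size(param))

   param = [1.0_real12, -2.0_real12, 0.5_real12]
   do step = 1, 50
      gradient = param                     !! dL/dparam for the toy loss
      call opt%minimise(param, gradient)   !! in-place momentum SGD update
   end do

   print *, 'final parameters:', param     !! should shrink towards zero
end program sgd_sketch
```

The same pattern should apply to the other optimisers: swap `sgd_optimiser_type` for `rmsprop_optimiser_type`, `adagrad_optimiser_type`, or `adam_optimiser_type` and pass the hyperparameters listed in the corresponding constructor interface.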