GCC Code Coverage Report


Directory: src/athena/
File: athena_optimiser.f90
Date: 2025-12-10 07:37:07
            Exec  Total  Coverage
Lines:         0      0    100.0%
Functions:     0      0        -%
Branches:      0      0        -%

Line Branch Exec Source
1 module athena__optimiser
2 !! Module containing implementations of optimisation methods
3 !!
4 !! This module implements gradient-based optimisers for training neural networks
5 !! by minimising loss functions through iterative parameter updates.
6 !!
7 !! Implemented optimisers:
8 !!
9 !! SGD (Stochastic Gradient Descent):
10 !! θ_{t+1} = θ_t - η * ∇L(θ_t)
11 !! Simple, reliable baseline optimiser
12 !!
13 !! SGD with Momentum:
14 !! v_{t+1} = μ*v_t + ∇L(θ_t)
15 !! θ_{t+1} = θ_t - η * v_{t+1}
16 !! Accelerates convergence, dampens oscillations
17 !!
18 !! RMSProp:
19 !! s_{t+1} = β*s_t + (1-β)*[∇L(θ_t)]²
20 !! θ_{t+1} = θ_t - η * ∇L(θ_t) / sqrt(s_{t+1} + ε)
21 !! Adapts learning rate per parameter, good for non-stationary objectives
22 !!
23 !! Adagrad:
24 !! s_{t+1} = s_t + [∇L(θ_t)]²
25 !! θ_{t+1} = θ_t - η * ∇L(θ_t) / sqrt(s_{t+1} + ε)
26 !! Adapts learning rate based on historical gradients
27 !!
28 !! Adam (Adaptive Moment Estimation):
29 !! m_{t+1} = β₁*m_t + (1-β₁)*∇L(θ_t) [first moment]
30 !! v_{t+1} = β₂*v_t + (1-β₂)*[∇L(θ_t)]² [second moment]
31 !! m̂ = m_{t+1}/(1-β₁^{t+1}), v̂ = v_{t+1}/(1-β₂^{t+1}) [bias correction]
32 !! θ_{t+1} = θ_t - η * m̂ / (sqrt(v̂) + ε)
33 !! Combines momentum and adaptive learning rates; a widely used default
34 !!
35 !! L-BFGS (Limited-memory BFGS):
36 !! Quasi-Newton method approximating Hessian inverse
37 !! Well suited to small-to-medium problems with smooth objectives
38 !!
39 !! where η is the learning rate, μ is the momentum coefficient, and β/β₁/β₂ are decay rates
40 !!
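!! As a worked illustration of the Adam update (using this module's default
!! hyperparameters η = 0.01, β₁ = 0.9, β₂ = 0.999, ε = 1e-8, and assuming a
!! hypothetical first gradient of 0.5 with m and v initialised to zero):
!!   m₁ = 0.1 * 0.5             = 0.05
!!   v₁ = 0.001 * 0.5²          = 2.5e-4
!!   m̂  = 0.05 / (1 - 0.9)      = 0.5
!!   v̂  = 2.5e-4 / (1 - 0.999)  = 0.25
!!   θ₁ = θ₀ - 0.01 * 0.5 / (sqrt(0.25) + 1e-8) ≈ θ₀ - 0.01
!! i.e. the magnitude of the first bias-corrected step is roughly η.
!!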
41 !! Attribution statement:
42 !! The following module is based on code from the neural-fortran library
43 !! https://github.com/modern-fortran/neural-fortran
44 !! The implementations of base_optimiser_type, sgd_optimiser_type,
45 !! rmsprop_optimiser_type, adagrad_optimiser_type, and adam_optimiser_type
46 !! are based on the corresponding types from neural-fortran
47 use coreutils, only: real32, stop_program
48 use athena__clipper, only: clip_type
49 use athena__regulariser, only: base_regulariser_type, l2_regulariser_type
50 use athena__learning_rate_decay, only: base_lr_decay_type
51 implicit none
52
53
54 private
55
56 public :: base_optimiser_type
57 public :: sgd_optimiser_type
58 public :: rmsprop_optimiser_type
59 public :: adagrad_optimiser_type
60 public :: adam_optimiser_type
61
62
63 !-------------------------------------------------------------------------------
64
65 type :: base_optimiser_type
66 !! Base optimiser type
67 character(len=20) :: name
68 !! Name of the optimiser
69 integer :: iter = 0
70 !! Iteration number
71 integer :: epoch = 0
72 !! Epoch number
73 real(real32) :: learning_rate = 0.01_real32
74 !! Learning rate hyperparameter
75 logical :: regularisation = .false.
76 !! Apply regularisation
77 class(base_regulariser_type), allocatable :: regulariser
78 !! Regularisation method
79 class(base_lr_decay_type), allocatable :: lr_decay
80 !! Learning rate decay method
81 type(clip_type) :: clip_dict
82 !! Clipping dictionary
83 contains
84 procedure, pass(this) :: init => init_base
85 !! Initialise base optimiser
86 procedure, pass(this) :: print_to_unit => print_to_unit_base
87 !! Print base optimiser information
88 procedure, pass(this) :: read => read_base
89 !! Read base optimiser information
90 procedure, pass(this) :: init_gradients => init_gradients_base
91 !! Initialise gradients
92 procedure, pass(this) :: minimise => minimise_base
93 !! Apply gradients to parameters to minimise loss using base optimiser
94 end type base_optimiser_type
95
96 interface base_optimiser_type
97 !! Interface for setting up the base optimiser
98 module function optimiser_setup_base( &
99 learning_rate, num_params, &
100 regulariser, clip_dict, lr_decay) result(optimiser)
101 !! Set up the base optimiser
102 real(real32), optional, intent(in) :: learning_rate
103 !! Learning rate
104 integer, optional, intent(in) :: num_params
105 !! Number of parameters
106 class(base_regulariser_type), optional, intent(in) :: regulariser
107 !! Regularisation method
108 type(clip_type), optional, intent(in) :: clip_dict
109 !! Clipping dictionary
110 class(base_lr_decay_type), optional, intent(in) :: lr_decay
111 !! Learning rate decay method
112 type(base_optimiser_type) :: optimiser
113 !! Instance of the base optimiser
114 end function optimiser_setup_base
115 end interface base_optimiser_type
116
117 !-------------------------------------------------------------------------------
118
119 type, extends(base_optimiser_type) :: sgd_optimiser_type
120 !! Stochastic gradient descent optimiser type
121 logical :: nesterov = .false.
122 !! Nesterov momentum
123 real(real32) :: momentum = 0._real32
124 !! Momentum coefficient (fraction of the previous velocity retained)
125 real(real32), allocatable, dimension(:) :: velocity
126 !! Velocity for momentum
127 contains
128 procedure, pass(this) :: init_gradients => init_gradients_sgd
129 !! Initialise gradients for SGD
130 procedure, pass(this) :: minimise => minimise_sgd
131 !! Apply gradients to parameters to minimise loss using SGD optimiser
132 end type sgd_optimiser_type
133
134 interface sgd_optimiser_type
135 !! Interface for setting up the SGD optimiser
136 module function optimiser_setup_sgd( &
137 learning_rate, momentum, &
138 nesterov, num_params, &
139 regulariser, clip_dict, lr_decay) result(optimiser)
140 !! Set up the SGD optimiser
141 real(real32), optional, intent(in) :: learning_rate, momentum
142 !! Learning rate and momentum
143 logical, optional, intent(in) :: nesterov
144 !! Nesterov momentum
145 integer, optional, intent(in) :: num_params
146 !! Number of parameters
147 class(base_regulariser_type), optional, intent(in) :: regulariser
148 !! Regularisation method
149 type(clip_type), optional, intent(in) :: clip_dict
150 !! Clipping dictionary
151 class(base_lr_decay_type), optional, intent(in) :: lr_decay
152 !! Learning rate decay method
153 type(sgd_optimiser_type) :: optimiser
154 !! Instance of the SGD optimiser
155 end function optimiser_setup_sgd
156 end interface sgd_optimiser_type
157
158 !-------------------------------------------------------------------------------
159
160 type, extends(base_optimiser_type) :: rmsprop_optimiser_type
161 !! RMSprop optimiser type
162 real(real32) :: beta = 0._real32
163 !! Decay rate for the moving average of squared gradients
164 real(real32) :: epsilon = 1.E-8_real32
165 !! Small constant for numerical stability
166 real(real32), allocatable, dimension(:) :: moving_avg
167 !! Moving average of the squared gradients
168 contains
169 procedure, pass(this) :: init_gradients => init_gradients_rmsprop
170 !! Initialise gradients for RMSprop
171 procedure, pass(this) :: minimise => minimise_rmsprop
172 !! Apply gradients to parameters to minimise loss using RMSprop optimiser
173 end type rmsprop_optimiser_type
174
175 interface rmsprop_optimiser_type
176 !! Interface for setting up the RMSprop optimiser
177 module function optimiser_setup_rmsprop( &
178 learning_rate, beta, &
179 epsilon, num_params, &
180 regulariser, clip_dict, lr_decay) result(optimiser)
181 !! Set up the RMSprop optimiser
182 real(real32), optional, intent(in) :: learning_rate, beta, epsilon
183 !! Learning rate, beta, and epsilon
184 integer, optional, intent(in) :: num_params
185 !! Number of parameters
186 class(base_regulariser_type), optional, intent(in) :: regulariser
187 !! Regularisation method
188 type(clip_type), optional, intent(in) :: clip_dict
189 !! Clipping dictionary
190 class(base_lr_decay_type), optional, intent(in) :: lr_decay
191 !! Learning rate decay method
192 type(rmsprop_optimiser_type) :: optimiser
193 !! Instance of the RMSprop optimiser
194 end function optimiser_setup_rmsprop
195 end interface rmsprop_optimiser_type
196
197 !-------------------------------------------------------------------------------
198
199 type, extends(base_optimiser_type) :: adagrad_optimiser_type
200 !! Adagrad optimiser type
201 real(real32) :: epsilon = 1.E-8_real32
202 !! Small constant for numerical stability
203 real(real32), allocatable, dimension(:) :: sum_squares
204 !! Sum of squares of gradients
205 contains
206 procedure, pass(this) :: init_gradients => init_gradients_adagrad
207 !! Initialise gradients for Adagrad
208 procedure, pass(this) :: minimise => minimise_adagrad
209 !! Apply gradients to parameters to minimise loss using Adagrad optimiser
210 end type adagrad_optimiser_type
211
212 interface adagrad_optimiser_type
213 !! Interface for setting up the Adagrad optimiser
214 module function optimiser_setup_adagrad( &
215 learning_rate, &
216 epsilon, num_params, &
217 regulariser, clip_dict, lr_decay) result(optimiser)
218 !! Set up the Adagrad optimiser
219 real(real32), optional, intent(in) :: learning_rate, epsilon
220 !! Learning rate and epsilon
221 integer, optional, intent(in) :: num_params
222 !! Number of parameters
223 class(base_regulariser_type), optional, intent(in) :: regulariser
224 !! Regularisation method
225 type(clip_type), optional, intent(in) :: clip_dict
226 !! Clipping dictionary
227 class(base_lr_decay_type), optional, intent(in) :: lr_decay
228 !! Learning rate decay method
229 type(adagrad_optimiser_type) :: optimiser
230 !! Instance of the Adagrad optimiser
231 end function optimiser_setup_adagrad
232 end interface adagrad_optimiser_type
233
234 !-------------------------------------------------------------------------------
235
236 type, extends(base_optimiser_type) :: adam_optimiser_type
237 !! Adam optimiser type
238 real(real32) :: beta1 = 0.9_real32
239 !! Decay rate for the first moment estimate
240 real(real32) :: beta2 = 0.999_real32
241 !! Decay rate for the second moment estimate
242 real(real32) :: epsilon = 1.E-8_real32
243 !! Small constant for numerical stability
244 real(real32), allocatable, dimension(:) :: m
245 !! First moment estimate
246 real(real32), allocatable, dimension(:) :: v
247 !! Second moment estimate
248 contains
249 procedure, pass(this) :: init_gradients => init_gradients_adam
250 !! Initialise gradients for Adam
251 procedure, pass(this) :: minimise => minimise_adam
252 !! Apply gradients to parameters to minimise loss using Adam optimiser
253 end type adam_optimiser_type
254
255 interface adam_optimiser_type
256 !! Interface for setting up the Adam optimiser
257 module function optimiser_setup_adam( &
258 learning_rate, &
259 beta1, beta2, epsilon, &
260 num_params, &
261 regulariser, clip_dict, lr_decay) result(optimiser)
262 !! Set up the Adam optimiser
263 real(real32), optional, intent(in) :: learning_rate
264 !! Learning rate
265 real(real32), optional, intent(in) :: beta1, beta2, epsilon
266 !! Beta1, beta2, and epsilon
267 integer, optional, intent(in) :: num_params
268 !! Number of parameters
269 class(base_regulariser_type), optional, intent(in) :: regulariser
270 !! Regularisation method
271 type(clip_type), optional, intent(in) :: clip_dict
272 !! Clipping dictionary
273 class(base_lr_decay_type), optional, intent(in) :: lr_decay
274 !! Learning rate decay method
275 type(adam_optimiser_type) :: optimiser
276 !! Instance of the Adam optimiser
277 end function optimiser_setup_adam
278 end interface adam_optimiser_type
279
280
281
282 contains
283
284 !###############################################################################
285 module function optimiser_setup_base( &
286 learning_rate, num_params, &
287 regulariser, clip_dict, lr_decay &
288 ) result(optimiser)
289 !! Set up the base optimiser
290 implicit none
291
292 ! Arguments
293 real(real32), optional, intent(in) :: learning_rate
294 !! Learning rate
295 integer, optional, intent(in) :: num_params
296 !! Number of parameters
297 class(base_regulariser_type), optional, intent(in) :: regulariser
298 !! Regularisation method
299 type(clip_type), optional, intent(in) :: clip_dict
300 !! Clipping dictionary
301 class(base_lr_decay_type), optional, intent(in) :: lr_decay
302 !! Learning rate decay method
303
304 type(base_optimiser_type) :: optimiser
305 !! Instance of the base optimiser
306
307 ! Local variables
308 integer :: num_params_
309 !! Number of parameters
310
311
312 ! Initialise optimiser name
313 optimiser%name = "base"
314
315 ! Apply regularisation
316 if(present(regulariser))then
317 optimiser%regularisation = .true.
318 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
319 allocate(optimiser%regulariser, source = regulariser)
320 end if
321
322 ! Apply clipping
323 if(present(clip_dict)) optimiser%clip_dict = clip_dict
324
325 ! Initialise general optimiser parameters
326 if(present(learning_rate)) optimiser%learning_rate = learning_rate
327
328 ! Initialise learning rate decay
329 if(present(lr_decay)) then
330 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
331 allocate(optimiser%lr_decay, source = lr_decay)
332 else
333 allocate(optimiser%lr_decay, source = base_lr_decay_type())
334 end if
335
336 ! Initialise gradients
337 if(present(num_params)) then
338 num_params_ = num_params
339 else
340 num_params_ = 1
341 end if
342 call optimiser%init_gradients(num_params_)
343 end function optimiser_setup_base
344 !###############################################################################
345
346
347 !###############################################################################
348 subroutine init_base(this, num_params, regulariser, clip_dict)
349 !! Initialise base optimiser
350 implicit none
351
352 ! Arguments
353 class(base_optimiser_type), intent(inout) :: this
354 !! Instance of the base optimiser
355 integer, intent(in) :: num_params
356 !! Number of parameters
357 class(base_regulariser_type), optional, intent(in) :: regulariser
358 !! Regularisation method
359 type(clip_type), optional, intent(in) :: clip_dict
360 !! Clipping dictionary
361
362
363 ! Apply regularisation
364 if(present(regulariser))then
365 this%regularisation = .true.
366 if(allocated(this%regulariser)) deallocate(this%regulariser)
367 allocate(this%regulariser, source = regulariser)
368 end if
369
370 ! Apply clipping
371 if(present(clip_dict)) this%clip_dict = clip_dict
372
373 ! Initialise gradients
374 call this%init_gradients(num_params)
375 end subroutine init_base
376 !###############################################################################
377
378
379 !###############################################################################
380 pure subroutine init_gradients_base(this, num_params)
381 !! Initialise gradients for base optimiser
382 implicit none
383
384 ! Arguments
385 class(base_optimiser_type), intent(inout) :: this
386 !! Instance of the base optimiser
387 integer, intent(in) :: num_params
388 !! Number of parameters
389
390 ! The base optimiser holds no gradient state, so there is nothing to allocate here
391 end subroutine init_gradients_base
392 !###############################################################################
393
394
395 !###############################################################################
396 pure subroutine minimise_base(this, param, gradient)
397 !! Apply gradients to parameters to minimise loss using base optimiser
398 implicit none
399
400 ! Arguments
401 class(base_optimiser_type), intent(inout) :: this
402 !! Instance of the base optimiser
403 real(real32), dimension(:), intent(inout) :: param
404 !! Parameters
405 real(real32), dimension(:), intent(inout) :: gradient
406 !! Gradients
407
408 ! Local variables
409 real(real32) :: learning_rate
410 !! Learning rate
411
412
413 ! Apply learning rate decay for the current iteration
414 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
415
416 ! Update parameters
417 param = param - learning_rate * gradient
418 end subroutine minimise_base
419 !###############################################################################
420
421
422 !##############################################################################!
423 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
424 !##############################################################################!
425
426
427 !###############################################################################
428 subroutine print_to_unit_base(this, unit)
429 !! Print base optimiser information
430 implicit none
431
432 ! Arguments
433 class(base_optimiser_type), intent(in) :: this
434 !! Instance of the base optimiser
435 integer, intent(in) :: unit
436 !! File unit
437
438
439 write(unit,'(6X,"NAME = ",A)') this%name
440 write(unit,'(6X,"LEARNING_RATE = ",F10.5)') this%learning_rate
441 write(unit,'(6X,"ITERATION = ",I10)') this%iter
442 write(unit,'(6X,"EPOCH = ",I10)') this%epoch
443 write(unit,'(6X,"REGULARISATION = ",L1)') this%regularisation
444
445 end subroutine print_to_unit_base
446 !###############################################################################
447
448
449 !###############################################################################
450 subroutine read_base(this, unit)
451 !! Read base optimiser information
452 use athena__tools_infile, only: assign_val, assign_vec
453 use coreutils, only: to_lower, to_upper, icount
454 implicit none
455
456 ! Arguments
457 class(base_optimiser_type), intent(inout) :: this
458 !! Instance of the base optimiser
459 integer, intent(in) :: unit
460 !! File unit
461
462 ! Local variables
463 integer :: stat
464 !! File status
465 integer :: itmp1
466 !! Temporary integer
467 character(256) :: buffer, tag, err_msg
468 !! Buffer for reading lines, tag for identifying lines, error message
469
470
471 ! Loop over tags in layer card
472 !---------------------------------------------------------------------------
473 tag_loop: do
474
475 ! Check for end of file
476 !------------------------------------------------------------------------
477 read(unit,'(A)',iostat=stat) buffer
478 if(stat.ne.0)then
479 write(err_msg,'("file read error (end of file?) before END ",A)') &
480 to_upper(this%name)
481 call stop_program(err_msg)
482 return
483 end if
484 if(trim(adjustl(buffer)).eq."") cycle tag_loop
485
486 ! Check for end of layer card
487 !------------------------------------------------------------------------
488 if(trim(adjustl(buffer)).eq."END OPTIMISER")then
489 backspace(unit)
490 exit tag_loop
491 end if
492
493 tag=trim(adjustl(buffer))
494 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
495
496 ! Read parameters from save file
497 !------------------------------------------------------------------------
498 select case(trim(tag))
499 case("NAME")
500 call assign_val(buffer, this%name, itmp1)
501 case("LEARNING_RATE")
502 call assign_val(buffer, this%learning_rate, itmp1)
503 case("ITERATION")
504 call assign_val(buffer, this%iter, itmp1)
505 case("EPOCH")
506 call assign_val(buffer, this%epoch, itmp1)
507 case("REGULARISATION")
508 call assign_val(buffer, this%regularisation, itmp1)
509 case default
510 ! Exclude 'e' from the alphabetic scan below, since it can appear
511 ! as the exponent of a number in scientific notation (e.g. 1.0E+00)
512 if(scan(to_lower(trim(adjustl(buffer))),&
513 'abcdfghijklmnopqrstuvwxyz').eq.0)then
514 cycle tag_loop
515 elseif(tag(:3).eq.'END')then
516 cycle tag_loop
517 end if
518 write(err_msg,'("Unrecognised line in input file: ",A)') &
519 trim(adjustl(buffer))
520 call stop_program(err_msg)
521 return
522 end select
523 end do tag_loop
524
525
526 ! Check for end of layer card
527 !---------------------------------------------------------------------------
528 read(unit,'(A)') buffer
529 if(trim(adjustl(buffer)).ne."END OPTIMISER")then
530 write(0,*) trim(adjustl(buffer))
531 write(err_msg,'("END OPTIMISER not where expected")')
532 call stop_program(err_msg)
533 return
534 end if
535
536 end subroutine read_base
537 !################################################################################
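! For reference, read_base parses a card of "TAG = value" lines terminated
! by "END OPTIMISER". A sketch of such a card (values are illustrative, and
! the surrounding delimiters are assumed to be handled by the caller):
!
!   NAME = adam
!   LEARNING_RATE = 0.00100
!   ITERATION = 42
!   EPOCH = 3
!   REGULARISATION = F
!   END OPTIMISER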
538
539
540 !##############################################################################!
541 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
542 !##############################################################################!
543
544
545 !###############################################################################
546 module function optimiser_setup_sgd( &
547 learning_rate, momentum, &
548 nesterov, num_params, &
549 regulariser, clip_dict, lr_decay) result(optimiser)
550 !! Set up the SGD optimiser
551 implicit none
552
553 ! Arguments
554 real(real32), optional, intent(in) :: learning_rate, momentum
555 !! Learning rate and momentum
556 logical, optional, intent(in) :: nesterov
557 !! Nesterov momentum
558 integer, optional, intent(in) :: num_params
559 !! Number of parameters
560 class(base_regulariser_type), optional, intent(in) :: regulariser
561 !! Regularisation method
562 type(clip_type), optional, intent(in) :: clip_dict
563 !! Clipping dictionary
564 class(base_lr_decay_type), optional, intent(in) :: lr_decay
565 !! Learning rate decay method
566
567 type(sgd_optimiser_type) :: optimiser
568 !! Instance of the SGD optimiser
569
570 ! Local variables
571 integer :: num_params_
572 !! Number of parameters
573
574
575 ! Initialise optimiser name
576 optimiser%name = "sgd"
577
578 ! Apply regularisation
579 if(present(regulariser))then
580 optimiser%regularisation = .true.
581 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
582 allocate(optimiser%regulariser, source = regulariser)
583 end if
584
585 ! Apply clipping
586 if(present(clip_dict)) optimiser%clip_dict = clip_dict
587
588 ! Initialise general optimiser parameters
589 if(present(learning_rate)) optimiser%learning_rate = learning_rate
590 if(present(momentum)) optimiser%momentum = momentum
591
592 ! Initialise learning rate decay
593 if(present(lr_decay)) then
594 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
595 allocate(optimiser%lr_decay, source = lr_decay)
596 else
597 allocate(optimiser%lr_decay, source = base_lr_decay_type())
598 end if
599
600 ! Initialise nesterov boolean
601 if(present(nesterov)) optimiser%nesterov = nesterov
602
603 ! Initialise gradients
604 if(present(num_params)) then
605 num_params_ = num_params
606 else
607 num_params_ = 1
608 end if
609 call optimiser%init_gradients(num_params_)
610 end function optimiser_setup_sgd
611 !###############################################################################
612
613
614 !###############################################################################
615 pure subroutine init_gradients_sgd(this, num_params)
616 !! Initialise gradients for SGD optimiser
617 implicit none
618
619 ! Arguments
620 class(sgd_optimiser_type), intent(inout) :: this
621 !! Instance of the SGD optimiser
622 integer, intent(in) :: num_params
623 !! Number of parameters
624
625
626 ! Initialise gradients
627 if(allocated(this%velocity)) deallocate(this%velocity)
628 allocate(this%velocity(num_params), source=0._real32)
629 end subroutine init_gradients_sgd
630 !###############################################################################
631
632
633 !###############################################################################
634 pure subroutine minimise_sgd(this, param, gradient)
635 !! Apply gradients to parameters to minimise loss using SGD optimiser
636 implicit none
637
638 ! Arguments
639 class(sgd_optimiser_type), intent(inout) :: this
640 !! Instance of the SGD optimiser
641 real(real32), dimension(:), intent(inout) :: param
642 !! Parameters
643 real(real32), dimension(:), intent(inout) :: gradient
644 !! Gradients
645
646 ! Local variables
647 real(real32) :: learning_rate
648 !! Learning rate
649
650
651 ! Apply learning rate decay for the current iteration
652 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
653
654 ! Apply regularisation
655 if(this%regularisation) &
656 call this%regulariser%regularise( param, gradient, learning_rate )
657
658 gradient = - learning_rate * gradient
659 ! Update parameters
660 if(this%momentum.gt.1.E-8_real32)then
661 ! Momentum-based update
662 this%velocity = this%momentum * this%velocity + gradient
663 if(this%nesterov)then
664 param = param + this%momentum * this%velocity + gradient
665 else
666 param = param + this%velocity
667 end if
668 else
669 ! Plain gradient descent update (no momentum)
670 this%velocity = gradient
671 param = param + this%velocity
672 end if
673 end subroutine minimise_sgd
674 !###############################################################################
675
676
677 !##############################################################################!
678 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
679 !##############################################################################!
680
681
682 !###############################################################################
683 module function optimiser_setup_rmsprop( &
684 learning_rate, beta, epsilon, &
685 num_params, regulariser, clip_dict, lr_decay &
686 ) result(optimiser)
687 !! Set up the RMSprop optimiser
688 implicit none
689
690 ! Arguments
691 real(real32), optional, intent(in) :: learning_rate
692 !! Learning rate
693 real(real32), optional, intent(in) :: beta, epsilon
694 !! Beta and epsilon
695 integer, optional, intent(in) :: num_params
696 !! Number of parameters
697 class(base_regulariser_type), optional, intent(in) :: regulariser
698 !! Regularisation method
699 type(clip_type), optional, intent(in) :: clip_dict
700 !! Clipping dictionary
701 class(base_lr_decay_type), optional, intent(in) :: lr_decay
702 !! Learning rate decay method
703
704 type(rmsprop_optimiser_type) :: optimiser
705 !! Instance of the RMSprop optimiser
706
707 ! Local variables
708 integer :: num_params_
709 !! Number of parameters
710
711
712 ! Initialise optimiser name
713 optimiser%name = "rmsprop"
714
715 ! Apply regularisation
716 if(present(regulariser))then
717 optimiser%regularisation = .true.
718 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
719 allocate(optimiser%regulariser, source = regulariser)
720 end if
721
722 ! Apply clipping
723 if(present(clip_dict)) optimiser%clip_dict = clip_dict
724
725 ! Initialise general optimiser parameters
726 if(present(learning_rate)) optimiser%learning_rate = learning_rate
727
728 ! Initialise learning rate decay
729 if(present(lr_decay)) then
730 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
731 allocate(optimiser%lr_decay, source = lr_decay)
732 else
733 allocate(optimiser%lr_decay, source = base_lr_decay_type())
734 end if
735
736 ! Initialise RMSprop parameters
737 if(present(beta)) optimiser%beta = beta
738 if(present(epsilon)) optimiser%epsilon = epsilon
739
740 ! Initialise gradients
741 if(present(num_params)) then
742 num_params_ = num_params
743 else
744 num_params_ = 1
745 end if
746 call optimiser%init_gradients(num_params_)
747 end function optimiser_setup_rmsprop
748 !###############################################################################
749
750
751 !###############################################################################
752 pure subroutine init_gradients_rmsprop(this, num_params)
753 !! Initialise gradients for RMSprop optimiser
754 implicit none
755
756 ! Arguments
757 class(rmsprop_optimiser_type), intent(inout) :: this
758 !! Instance of the RMSprop optimiser
759 integer, intent(in) :: num_params
760 !! Number of parameters
761
762
763 ! Initialise gradients
764 if(allocated(this%moving_avg)) deallocate(this%moving_avg)
765 allocate(this%moving_avg(num_params), source=0._real32)
766 end subroutine init_gradients_rmsprop
767 !###############################################################################
768
769
770 !###############################################################################
771 pure subroutine minimise_rmsprop(this, param, gradient)
772 !! Apply gradients to parameters to minimise loss using RMSprop optimiser
773 implicit none
774
775 ! Arguments
776 class(rmsprop_optimiser_type), intent(inout) :: this
777 !! Instance of the RMSprop optimiser
778 real(real32), dimension(:), intent(inout) :: param
779 !! Parameters
780 real(real32), dimension(:), intent(inout) :: gradient
781 !! Gradients
782
783 ! Local variables
784 real(real32) :: learning_rate
785 !! Learning rate
786
787
788 ! Apply learning rate decay for the current iteration
789 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
790
791 ! Apply regularisation
792 if(this%regularisation) &
793 call this%regulariser%regularise( param, gradient, learning_rate )
794
795 this%moving_avg = this%beta * this%moving_avg + &
796 (1._real32 - this%beta) * gradient ** 2._real32
797
798 ! Update parameters
799 param = param - learning_rate * gradient / &
800 (sqrt(this%moving_avg + this%epsilon))
801 end subroutine minimise_rmsprop
802 !###############################################################################
803
804
805 !##############################################################################!
806 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
807 !##############################################################################!
808
809
810 !###############################################################################
811 module function optimiser_setup_adagrad( &
812 learning_rate, epsilon, &
813 num_params, regulariser, clip_dict, lr_decay &
814 ) result(optimiser)
815 !! Set up the Adagrad optimiser
816 implicit none
817
818 ! Arguments
819 real(real32), optional, intent(in) :: learning_rate
820 !! Learning rate
821 real(real32), optional, intent(in) :: epsilon
822 !! Epsilon
823 integer, optional, intent(in) :: num_params
824 !! Number of parameters
825 class(base_regulariser_type), optional, intent(in) :: regulariser
826 !! Regularisation method
827 type(clip_type), optional, intent(in) :: clip_dict
828 !! Clipping dictionary
829 class(base_lr_decay_type), optional, intent(in) :: lr_decay
830 !! Learning rate decay method
831
832 type(adagrad_optimiser_type) :: optimiser
833 !! Instance of the Adagrad optimiser
834
835 ! Local variables
836 integer :: num_params_
837 !! Number of parameters
838
839
840 ! Initialise optimiser name
841 optimiser%name = "adagrad"
842
843 ! Apply regularisation
844 if(present(regulariser))then
845 optimiser%regularisation = .true.
846 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
847 allocate(optimiser%regulariser, source = regulariser)
848 end if
849
850 ! Apply clipping
851 if(present(clip_dict)) optimiser%clip_dict = clip_dict
852
853 ! Initialise general optimiser parameters
854 if(present(learning_rate)) optimiser%learning_rate = learning_rate
855
856 ! Initialise learning rate decay
857 if(present(lr_decay)) then
858 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
859 allocate(optimiser%lr_decay, source = lr_decay)
860 else
861 allocate(optimiser%lr_decay, source = base_lr_decay_type())
862 end if
863
864 ! Initialise Adagrad parameters
865 if(present(epsilon)) optimiser%epsilon = epsilon
866
867 ! Initialise gradients
868 if(present(num_params)) then
869 num_params_ = num_params
870 else
871 num_params_ = 1
872 end if
873 call optimiser%init_gradients(num_params_)
874 end function optimiser_setup_adagrad
875 !###############################################################################
876
877
878 !###############################################################################
879 pure subroutine init_gradients_adagrad(this, num_params)
880 !! Initialise gradients for Adagrad optimiser
881 implicit none
882
883 ! Arguments
884 class(adagrad_optimiser_type), intent(inout) :: this
885 !! Instance of the Adagrad optimiser
886 integer, intent(in) :: num_params
887 !! Number of parameters
888
889
890 ! Initialise gradients
891 if(allocated(this%sum_squares)) deallocate(this%sum_squares)
892 allocate(this%sum_squares(num_params), source=0._real32)
893 end subroutine init_gradients_adagrad
894 !###############################################################################
895
896
897 !###############################################################################
898 pure subroutine minimise_adagrad(this, param, gradient)
899 !! Apply gradients to parameters to minimise loss using Adagrad optimiser
900 implicit none
901
902 ! Arguments
903 class(adagrad_optimiser_type), intent(inout) :: this
904 !! Instance of the Adagrad optimiser
905 real(real32), dimension(:), intent(inout) :: param
906 !! Parameters
907 real(real32), dimension(:), intent(inout) :: gradient
908 !! Gradients
909
910 real(real32) :: learning_rate
911 !! Learning rate
912
913
914 ! Apply learning rate decay for the current iteration
915 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
916
917 ! Apply regularisation
918 if(this%regularisation) &
919 call this%regulariser%regularise( param, gradient, learning_rate )
920
921 this%sum_squares = this%sum_squares + gradient ** 2._real32
922
923 ! Update parameters
924 param = param - learning_rate * gradient / &
925 (sqrt(this%sum_squares + this%epsilon))
926 end subroutine minimise_adagrad
927 !###############################################################################
928
929
930 !##############################################################################!
931 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
932 !##############################################################################!
933
934
935 !###############################################################################
936 module function optimiser_setup_adam( &
937 learning_rate, beta1, beta2, epsilon, &
938 num_params, regulariser, clip_dict, lr_decay &
939 ) result(optimiser)
940 !! Set up the Adam optimiser
941 implicit none
942
943 ! Arguments
944 real(real32), optional, intent(in) :: learning_rate
945 !! Learning rate
946 real(real32), optional, intent(in) :: beta1, beta2, epsilon
947 !! Beta1, beta2, and epsilon
948 integer, optional, intent(in) :: num_params
949 !! Number of parameters
950 class(base_regulariser_type), optional, intent(in) :: regulariser
951 !! Regularisation method
952 type(clip_type), optional, intent(in) :: clip_dict
953 !! Clipping dictionary
954 class(base_lr_decay_type), optional, intent(in) :: lr_decay
955 !! Learning rate decay method
956
957 type(adam_optimiser_type) :: optimiser
958 !! Instance of the Adam optimiser
959
960 ! Local variables
961 integer :: num_params_
962 !! Number of parameters
963
964
965 ! Initialise optimiser name
966 optimiser%name = "adam"
967
968 ! Apply regularisation
969 if(present(regulariser))then
970 optimiser%regularisation = .true.
971 if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser)
972 allocate(optimiser%regulariser, source = regulariser)
973 end if
974
975 ! Apply clipping
976 if(present(clip_dict)) optimiser%clip_dict = clip_dict
977
978 ! Initialise general optimiser parameters
979 if(present(learning_rate)) optimiser%learning_rate = learning_rate
980
981 ! Initialise learning rate decay
982 if(present(lr_decay)) then
983 if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay)
984 allocate(optimiser%lr_decay, source = lr_decay)
985 else
986 allocate(optimiser%lr_decay, source = base_lr_decay_type())
987 end if
988
989 ! Initialise Adam parameters
990 if(present(beta1)) optimiser%beta1 = beta1
991 if(present(beta2)) optimiser%beta2 = beta2
992 if(present(epsilon)) optimiser%epsilon = epsilon
993
994 ! Initialise gradients
995 if(present(num_params)) then
996 num_params_ = num_params
997 else
998 num_params_ = 1
999 end if
1000 call optimiser%init_gradients(num_params_)
1001 end function optimiser_setup_adam
1002 !###############################################################################
1003
1004
1005 !###############################################################################
1006 pure subroutine init_gradients_adam(this, num_params)
1007 !! Initialise gradients for Adam optimiser
1008 implicit none
1009
1010 ! Arguments
1011 class(adam_optimiser_type), intent(inout) :: this
1012 !! Instance of the Adam optimiser
1013 integer, intent(in) :: num_params
1014 !! Number of parameters
1015
1016
1017 ! Initialise gradients
1018 if(allocated(this%m)) deallocate(this%m)
1019 if(allocated(this%v)) deallocate(this%v)
1020 allocate(this%m(num_params), source=0._real32)
1021 allocate(this%v(num_params), source=0._real32)
1022 end subroutine init_gradients_adam
1023 !###############################################################################
1024
1025
1026 !###############################################################################
1027 pure subroutine minimise_adam(this, param, gradient)
1028 !! Apply gradients to parameters to minimise loss using Adam optimiser
1029 implicit none
1030
1031 ! Arguments
1032 class(adam_optimiser_type), intent(inout) :: this
1033 !! Instance of the Adam optimiser
1034 real(real32), dimension(:), intent(inout) :: param
1035 !! Parameters
1036 real(real32), dimension(:), intent(inout) :: gradient
1037 !! Gradients
1038
1039 ! Local variables
1040 real(real32) :: learning_rate
1041 !! Learning rate
1042
1043
1044 ! Apply learning rate decay for the current iteration
1045 learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter)
1046
1047 ! Apply regularisation
1048 if(this%regularisation) &
1049 call this%regulariser%regularise( param, gradient, learning_rate )
1050
1051 ! Adaptive learning method
1052 this%m = this%beta1 * this%m + &
1053 (1._real32 - this%beta1) * gradient
1054 this%v = this%beta2 * this%v + &
1055 (1._real32 - this%beta2) * gradient ** 2._real32
1056
1057 ! Update parameters
1058 associate( &
1059 m_hat => this%m / (1._real32 - this%beta1**this%iter), &
1060 v_hat => this%v / (1._real32 - this%beta2**this%iter) )
1061 select type(regulariser => this%regulariser)
1062 type is (l2_regulariser_type)
1063 select case(regulariser%decoupled)
1064 case(.true.)
1065 ! decoupled weight decay (AdamW)
1066 param = param - learning_rate * &
1067 ( &
1068 m_hat / (sqrt(v_hat) + this%epsilon) + &
1069 regulariser%l2 * param &
1070 )
1071 case(.false.)
1072 ! classical L2 regularisation (included in gradient)
1073 param = param - learning_rate * ( &
1074 ( m_hat + regulariser%l2 * param ) / &
1075 ( sqrt(v_hat) + this%epsilon ) &
1076 )
1077 end select
1078 class default
1079 ! any other (or no) regulariser: standard Adam update
1080 param = param - learning_rate * ( &
1081 m_hat / (sqrt(v_hat) + this%epsilon) )
1082 end select
1083 end associate
1084 end subroutine minimise_adam
1085 !###############################################################################
1086
1087 end module athena__optimiser
1088
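
A minimal usage sketch (not part of the athena sources; program name and values
are illustrative, and it assumes the library is built and that coreutils'
real32 matches the standard real32 kind from iso_fortran_env). It constructs an
SGD optimiser with momentum and applies a single update step through the
minimise type-bound procedure:

    program optimiser_example
       use, intrinsic :: iso_fortran_env, only: real32
       use athena__optimiser, only: sgd_optimiser_type
       implicit none

       type(sgd_optimiser_type) :: opt
       real(real32) :: param(3), gradient(3)

       ! Construct an SGD optimiser with momentum; num_params sizes the
       ! internal velocity buffer to match the parameter vector
       opt = sgd_optimiser_type(learning_rate=0.01_real32, &
            momentum=0.9_real32, num_params=3)

       param    = [ 1.0_real32, -2.0_real32,  0.5_real32]
       gradient = [ 0.1_real32,  0.3_real32, -0.2_real32]

       ! One optimisation step; note that minimise rescales gradient in place
       call opt%minimise(param, gradient)

       print *, param
    end program optimiser_example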