| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__optimiser | ||
| 2 | !! Module containing implementations of optimisation methods | ||
| 3 | !! | ||
| 4 | !! This module implements gradient-based optimisers for training neural networks | ||
| 5 | !! by minimising loss functions through iterative parameter updates. | ||
| 6 | !! | ||
| 7 | !! Implemented optimisers: | ||
| 8 | !! | ||
| 9 | !! SGD (Stochastic Gradient Descent): | ||
| 10 | !! θ_{t+1} = θ_t - η * ∇L(θ_t) | ||
| 11 | !! Simple, reliable baseline optimiser | ||
| 12 | !! | ||
| 13 | !! SGD with Momentum: | ||
| 14 | !! v_{t+1} = μ*v_t + ∇L(θ_t) | ||
| 15 | !! θ_{t+1} = θ_t - η * v_{t+1} | ||
| 16 | !! Accelerates convergence, dampens oscillations | ||
| 17 | !! | ||
| 18 | !! RMSProp: | ||
| 19 | !! s_{t+1} = β*s_t + (1-β)*[∇L(θ_t)]² | ||
| 20 | !! θ_{t+1} = θ_t - η * ∇L(θ_t) / sqrt(s_{t+1} + ε) | ||
| 21 | !! Adapts learning rate per parameter, good for non-stationary objectives | ||
| 22 | !! | ||
| 23 | !! Adagrad: | ||
| 24 | !! s_{t+1} = s_t + [∇L(θ_t)]² | ||
| 25 | !! θ_{t+1} = θ_t - η * ∇L(θ_t) / sqrt(s_{t+1} + ε) | ||
| 26 | !! Adapts learning rate based on historical gradients | ||
| 27 | !! | ||
| 28 | !! Adam (Adaptive Moment Estimation): | ||
| 29 | !! m_{t+1} = β₁*m_t + (1-β₁)*∇L(θ_t) [first moment] | ||
| 30 | !! v_{t+1} = β₂*v_t + (1-β₂)*[∇L(θ_t)]² [second moment] | ||
| 31 | !! m̂ = m_{t+1}/(1-β₁^{t+1}), v̂ = v_{t+1}/(1-β₂^{t+1}) [bias correction] | ||
| 32 | !! θ_{t+1} = θ_t - η * m̂ / (sqrt(v̂) + ε) | ||
| 33 | !! Combines momentum and adaptive learning rates, most popular choice | ||
| 34 | !! | ||
| 35 | !! L-BFGS (Limited-memory BFGS) (not currently provided as a type in this module): | ||
| 36 | !! Quasi-Newton method approximating the inverse Hessian | ||
| 37 | !! Suited to small-to-medium problems with smooth objectives | ||
| 38 | !! | ||
| 39 | !! where η is the learning rate, μ is the momentum coefficient, β/β₁/β₂ are | ||
| | !! exponential decay rates, and ε is a small constant for numerical stability | ||
| 40 | !! | ||
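| | !! Example usage (an illustrative sketch only; `params` and `gradients` are | ||
| | !! assumed to be rank-1 real(real32) arrays managed by the calling network): | ||
| | !! | ||
| | !!   type(sgd_optimiser_type) :: opt | ||
| | !!   opt = sgd_optimiser_type(learning_rate=0.01_real32, num_params=size(params)) | ||
| | !!   call opt%minimise(params, gradients) | ||
| | !! | ||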
| 41 | !! Attribution statement: | ||
| 42 | !! The following module is based on code from the neural-fortran library | ||
| 43 | !! https://github.com/modern-fortran/neural-fortran | ||
| 44 | !! The implementations of base_optimiser_type, sgd_optimiser_type, | ||
| 45 | !! rmsprop_optimiser_type, adagrad_optimiser_type, and adam_optimiser_type | ||
| 46 | !! are based on the corresponding types from neural-fortran | ||
| 47 | use coreutils, only: real32, stop_program | ||
| 48 | use athena__clipper, only: clip_type | ||
| 49 | use athena__regulariser, only: base_regulariser_type, l2_regulariser_type | ||
| 50 | use athena__learning_rate_decay, only: base_lr_decay_type | ||
| 51 | implicit none | ||
| 52 | |||
| 53 | |||
| 54 | private | ||
| 55 | |||
| 56 | public :: base_optimiser_type | ||
| 57 | public :: sgd_optimiser_type | ||
| 58 | public :: rmsprop_optimiser_type | ||
| 59 | public :: adagrad_optimiser_type | ||
| 60 | public :: adam_optimiser_type | ||
| 61 | |||
| 62 | |||
| 63 | !------------------------------------------------------------------------------- | ||
| 64 | |||
| 65 | type :: base_optimiser_type | ||
| 66 | !! Base optimiser type | ||
| 67 | character(len=20) :: name | ||
| 68 | !! Name of the optimiser | ||
| 69 | integer :: iter = 0 | ||
| 70 | !! Iteration number | ||
| 71 | integer :: epoch = 0 | ||
| 72 | !! Epoch number | ||
| 73 | real(real32) :: learning_rate = 0.01_real32 | ||
| 74 | !! Learning rate hyperparameter | ||
| 75 | logical :: regularisation = .false. | ||
| 76 | !! Apply regularisation | ||
| 77 | class(base_regulariser_type), allocatable :: regulariser | ||
| 78 | !! Regularisation method | ||
| 79 | class(base_lr_decay_type), allocatable :: lr_decay | ||
| 80 | !! Learning rate decay method | ||
| 81 | type(clip_type) :: clip_dict | ||
| 82 | !! Clipping dictionary | ||
| 83 | contains | ||
| 84 | procedure, pass(this) :: init => init_base | ||
| 85 | !! Initialise base optimiser | ||
| 86 | procedure, pass(this) :: print_to_unit => print_to_unit_base | ||
| 87 | !! Print base optimiser information | ||
| 88 | procedure, pass(this) :: read => read_base | ||
| 89 | !! Read base optimiser information | ||
| 90 | procedure, pass(this) :: init_gradients => init_gradients_base | ||
| 91 | !! Initialise gradients | ||
| 92 | procedure, pass(this) :: minimise => minimise_base | ||
| 93 | !! Apply gradients to parameters to minimise loss using base optimiser | ||
| 94 | end type base_optimiser_type | ||
| 95 | |||
| 96 | interface base_optimiser_type | ||
| 97 | !! Interface for setting up the base optimiser | ||
| 98 | module function optimiser_setup_base( & | ||
| 99 | learning_rate, num_params, & | ||
| 100 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
| 101 | !! Set up the base optimiser | ||
| 102 | real(real32), optional, intent(in) :: learning_rate | ||
| 103 | !! Learning rate | ||
| 104 | integer, optional, intent(in) :: num_params | ||
| 105 | !! Number of parameters | ||
| 106 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 107 | !! Regularisation method | ||
| 108 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 109 | !! Clipping dictionary | ||
| 110 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 111 | !! Learning rate decay method | ||
| 112 | type(base_optimiser_type) :: optimiser | ||
| 113 | !! Instance of the base optimiser | ||
| 114 | end function optimiser_setup_base | ||
| 115 | end interface base_optimiser_type | ||
| 116 | |||
| 117 | !------------------------------------------------------------------------------- | ||
| 118 | |||
| 119 | type, extends(base_optimiser_type) :: sgd_optimiser_type | ||
| 120 | !! Stochastic gradient descent optimiser type | ||
| 121 | logical :: nesterov = .false. | ||
| 122 | !! Nesterov momentum | ||
| 123 | real(real32) :: momentum = 0._real32 | ||
| 124 | !! Momentum coefficient (fraction of the previous velocity retained) | ||
| 125 | real(real32), allocatable, dimension(:) :: velocity | ||
| 126 | !! Velocity for momentum | ||
| 127 | contains | ||
| 128 | procedure, pass(this) :: init_gradients => init_gradients_sgd | ||
| 129 | !! Initialise gradients for SGD | ||
| 130 | procedure, pass(this) :: minimise => minimise_sgd | ||
| 131 | !! Apply gradients to parameters to minimise loss using SGD optimiser | ||
| 132 | end type sgd_optimiser_type | ||
| 133 | |||
| 134 | interface sgd_optimiser_type | ||
| 135 | !! Interface for setting up the SGD optimiser | ||
| 136 | module function optimiser_setup_sgd( & | ||
| 137 | learning_rate, momentum, & | ||
| 138 | nesterov, num_params, & | ||
| 139 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
| 140 | !! Set up the SGD optimiser | ||
| 141 | real(real32), optional, intent(in) :: learning_rate, momentum | ||
| 142 | !! Learning rate and momentum | ||
| 143 | logical, optional, intent(in) :: nesterov | ||
| 144 | !! Nesterov momentum | ||
| 145 | integer, optional, intent(in) :: num_params | ||
| 146 | !! Number of parameters | ||
| 147 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 148 | !! Regularisation method | ||
| 149 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 150 | !! Clipping dictionary | ||
| 151 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 152 | !! Learning rate decay method | ||
| 153 | type(sgd_optimiser_type) :: optimiser | ||
| 154 | !! Instance of the SGD optimiser | ||
| 155 | end function optimiser_setup_sgd | ||
| 156 | end interface sgd_optimiser_type | ||
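| | ! Construction sketch with momentum and Nesterov acceleration enabled | ||
| | ! (hyperparameter values are illustrative, not tuned defaults): | ||
| | ! | ||
| | !   opt = sgd_optimiser_type(learning_rate=0.01_real32, momentum=0.9_real32, & | ||
| | !        nesterov=.true., num_params=size(params)) | ||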
| 157 | |||
| 158 | !------------------------------------------------------------------------------- | ||
| 159 | |||
| 160 | type, extends(base_optimiser_type) :: rmsprop_optimiser_type | ||
| 161 | !! RMSprop optimiser type | ||
| 162 | real(real32) :: beta = 0._real32 | ||
| 163 | !! Decay rate for the moving average of squared gradients | ||
| 164 | real(real32) :: epsilon = 1.E-8_real32 | ||
| 165 | !! Small constant added for numerical stability | ||
| 166 | real(real32), allocatable, dimension(:) :: moving_avg | ||
| 167 | !! Moving average | ||
| 168 | contains | ||
| 169 | procedure, pass(this) :: init_gradients => init_gradients_rmsprop | ||
| 170 | !! Initialise gradients for RMSprop | ||
| 171 | procedure, pass(this) :: minimise => minimise_rmsprop | ||
| 172 | !! Apply gradients to parameters to minimise loss using RMSprop optimiser | ||
| 173 | end type rmsprop_optimiser_type | ||
| 174 | |||
| 175 | interface rmsprop_optimiser_type | ||
| 176 | !! Interface for setting up the RMSprop optimiser | ||
| 177 | module function optimiser_setup_rmsprop( & | ||
| 178 | learning_rate, beta, & | ||
| 179 | epsilon, num_params, & | ||
| 180 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
| 181 | !! Set up the RMSprop optimiser | ||
| 182 | real(real32), optional, intent(in) :: learning_rate, beta, epsilon | ||
| 183 | !! Learning rate, beta, and epsilon | ||
| 184 | integer, optional, intent(in) :: num_params | ||
| 185 | !! Number of parameters | ||
| 186 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 187 | !! Regularisation method | ||
| 188 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 189 | !! Clipping dictionary | ||
| 190 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 191 | !! Learning rate decay method | ||
| 192 | type(rmsprop_optimiser_type) :: optimiser | ||
| 193 | !! Instance of the RMSprop optimiser | ||
| 194 | end function optimiser_setup_rmsprop | ||
| 195 | end interface rmsprop_optimiser_type | ||
| 196 | |||
| 197 | !------------------------------------------------------------------------------- | ||
| 198 | |||
| 199 | type, extends(base_optimiser_type) :: adagrad_optimiser_type | ||
| 200 | !! Adagrad optimiser type | ||
| 201 | real(real32) :: epsilon = 1.E-8_real32 | ||
| 202 | !! Small constant added for numerical stability | ||
| 203 | real(real32), allocatable, dimension(:) :: sum_squares | ||
| 204 | !! Sum of squares of gradients | ||
| 205 | contains | ||
| 206 | procedure, pass(this) :: init_gradients => init_gradients_adagrad | ||
| 207 | !! Initialise gradients for Adagrad | ||
| 208 | procedure, pass(this) :: minimise => minimise_adagrad | ||
| 209 | !! Apply gradients to parameters to minimise loss using Adagrad optimiser | ||
| 210 | end type adagrad_optimiser_type | ||
| 211 | |||
| 212 | interface adagrad_optimiser_type | ||
| 213 | !! Interface for setting up the Adagrad optimiser | ||
| 214 | module function optimiser_setup_adagrad( & | ||
| 215 | learning_rate, & | ||
| 216 | epsilon, num_params, & | ||
| 217 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
| 218 | !! Set up the Adagrad optimiser | ||
| 219 | real(real32), optional, intent(in) :: learning_rate, epsilon | ||
| 220 | !! Learning rate and epsilon | ||
| 221 | integer, optional, intent(in) :: num_params | ||
| 222 | !! Number of parameters | ||
| 223 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 224 | !! Regularisation method | ||
| 225 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 226 | !! Clipping dictionary | ||
| 227 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 228 | !! Learning rate decay method | ||
| 229 | type(adagrad_optimiser_type) :: optimiser | ||
| 230 | !! Instance of the Adagrad optimiser | ||
| 231 | end function optimiser_setup_adagrad | ||
| 232 | end interface adagrad_optimiser_type | ||
| 233 | |||
| 234 | !------------------------------------------------------------------------------- | ||
| 235 | |||
| 236 | type, extends(base_optimiser_type) :: adam_optimiser_type | ||
| 237 | !! Adam optimiser type | ||
| 238 | real(real32) :: beta1 = 0.9_real32 | ||
| 239 | !! Exponential decay rate for the first moment estimate | ||
| 240 | real(real32) :: beta2 = 0.999_real32 | ||
| 241 | !! Exponential decay rate for the second moment estimate | ||
| 242 | real(real32) :: epsilon = 1.E-8_real32 | ||
| 243 | !! Small constant added for numerical stability | ||
| 244 | real(real32), allocatable, dimension(:) :: m | ||
| 245 | !! First moment estimate | ||
| 246 | real(real32), allocatable, dimension(:) :: v | ||
| 247 | !! Second moment estimate | ||
| 248 | contains | ||
| 249 | procedure, pass(this) :: init_gradients => init_gradients_adam | ||
| 250 | !! Initialise gradients for Adam | ||
| 251 | procedure, pass(this) :: minimise => minimise_adam | ||
| 252 | !! Apply gradients to parameters to minimise loss using Adam optimiser | ||
| 253 | end type adam_optimiser_type | ||
| 254 | |||
| 255 | interface adam_optimiser_type | ||
| 256 | !! Interface for setting up the Adam optimiser | ||
| 257 | module function optimiser_setup_adam( & | ||
| 258 | learning_rate, & | ||
| 259 | beta1, beta2, epsilon, & | ||
| 260 | num_params, & | ||
| 261 | regulariser, clip_dict, lr_decay) result(optimiser) | ||
| 262 | !! Set up the Adam optimiser | ||
| 263 | real(real32), optional, intent(in) :: learning_rate | ||
| 264 | !! Learning rate | ||
| 265 | real(real32), optional, intent(in) :: beta1, beta2, epsilon | ||
| 266 | !! Beta1, beta2, and epsilon | ||
| 267 | integer, optional, intent(in) :: num_params | ||
| 268 | !! Number of parameters | ||
| 269 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 270 | !! Regularisation method | ||
| 271 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 272 | !! Clipping dictionary | ||
| 273 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 274 | !! Learning rate decay method | ||
| 275 | type(adam_optimiser_type) :: optimiser | ||
| 276 | !! Instance of the Adam optimiser | ||
| 277 | end function optimiser_setup_adam | ||
| 278 | end interface adam_optimiser_type | ||
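| | ! Sketch of decoupled weight decay (AdamW-style updates) via an L2 regulariser. | ||
| | ! The l2_regulariser_type constructor arguments shown are assumptions inferred | ||
| | ! from the components referenced in minimise_adam below: | ||
| | ! | ||
| | !   type(adam_optimiser_type) :: opt | ||
| | !   opt = adam_optimiser_type(learning_rate=1.E-3_real32, & | ||
| | !        regulariser=l2_regulariser_type(l2=1.E-2_real32, decoupled=.true.), & | ||
| | !        num_params=size(params)) | ||
| | !   opt%iter = opt%iter + 1   ! advance before each update; the bias correction | ||
| | !                             ! divides by (1 - beta**iter) | ||
| | !   call opt%minimise(params, gradients) | ||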
| 279 | |||
| 280 | |||
| 281 | |||
| 282 | contains | ||
| 283 | |||
| 284 | !############################################################################### | ||
| 285 | − | module function optimiser_setup_base( & | |
| 286 | learning_rate, num_params, & | ||
| 287 | regulariser, clip_dict, lr_decay & | ||
| 288 | − | ) result(optimiser) | |
| 289 | !! Set up the base optimiser | ||
| 290 | implicit none | ||
| 291 | |||
| 292 | ! Arguments | ||
| 293 | real(real32), optional, intent(in) :: learning_rate | ||
| 294 | !! Learning rate | ||
| 295 | integer, optional, intent(in) :: num_params | ||
| 296 | !! Number of parameters | ||
| 297 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 298 | !! Regularisation method | ||
| 299 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 300 | !! Clipping dictionary | ||
| 301 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 302 | !! Learning rate decay method | ||
| 303 | |||
| 304 | type(base_optimiser_type) :: optimiser | ||
| 305 | !! Instance of the base optimiser | ||
| 306 | |||
| 307 | ! Local variables | ||
| 308 | integer :: num_params_ | ||
| 309 | !! Number of parameters | ||
| 310 | |||
| 311 | |||
| 312 | ! Initialise optimiser name | ||
| 313 | − | optimiser%name = "base" | |
| 314 | |||
| 315 | ! Apply regularisation | ||
| 316 | − | if(present(regulariser))then | |
| 317 | − | optimiser%regularisation = .true. | |
| 318 | − | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) | |
| 319 | − | allocate(optimiser%regulariser, source = regulariser) | |
| 320 | end if | ||
| 321 | |||
| 322 | ! Apply clipping | ||
| 323 | − | if(present(clip_dict)) optimiser%clip_dict = clip_dict | |
| 324 | |||
| 325 | ! Initialise general optimiser parameters | ||
| 326 | − | if(present(learning_rate)) optimiser%learning_rate = learning_rate | |
| 327 | |||
| 328 | ! Initialise learning rate decay | ||
| 329 | − | if(present(lr_decay)) then | |
| 330 | − | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) | |
| 331 | − | allocate(optimiser%lr_decay, source = lr_decay) | |
| 332 | else | ||
| 333 | − | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
| 334 | end if | ||
| 335 | |||
| 336 | ! Initialise gradients | ||
| 337 | − | if(present(num_params)) then | |
| 338 | − | num_params_ = num_params | |
| 339 | else | ||
| 340 | − | num_params_ = 1 | |
| 341 | end if | ||
| 342 | − | call optimiser%init_gradients(num_params_) | |
| 343 | − | end function optimiser_setup_base | |
| 344 | !############################################################################### | ||
| 345 | |||
| 346 | |||
| 347 | !############################################################################### | ||
| 348 | − | subroutine init_base(this, num_params, regulariser, clip_dict) | |
| 349 | !! Initialise base optimiser | ||
| 350 | implicit none | ||
| 351 | |||
| 352 | ! Arguments | ||
| 353 | class(base_optimiser_type), intent(inout) :: this | ||
| 354 | !! Instance of the base optimiser | ||
| 355 | integer, intent(in) :: num_params | ||
| 356 | !! Number of parameters | ||
| 357 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 358 | !! Regularisation method | ||
| 359 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 360 | !! Clipping dictionary | ||
| 361 | |||
| 362 | |||
| 363 | ! Apply regularisation | ||
| 364 | − | if(present(regulariser))then | |
| 365 | − | this%regularisation = .true. | |
| 366 | − | if(allocated(this%regulariser)) deallocate(this%regulariser) | |
| 367 | − | allocate(this%regulariser, source = regulariser) | |
| 368 | end if | ||
| 369 | |||
| 370 | ! Apply clipping | ||
| 371 | − | if(present(clip_dict)) this%clip_dict = clip_dict | |
| 372 | |||
| 373 | ! Initialise gradients | ||
| 374 | − | call this%init_gradients(num_params) | |
| 375 | − | end subroutine init_base | |
| 376 | !############################################################################### | ||
| 377 | |||
| 378 | |||
| 379 | !############################################################################### | ||
| 380 | − | pure subroutine init_gradients_base(this, num_params) | |
| 381 | !! Initialise gradients for base optimiser | ||
| 382 | implicit none | ||
| 383 | |||
| 384 | ! Arguments | ||
| 385 | class(base_optimiser_type), intent(inout) :: this | ||
| 386 | !! Instance of the base optimiser | ||
| 387 | integer, intent(in) :: num_params | ||
| 388 | !! Number of parameters | ||
| 389 | |||
| 390 | ! The base optimiser keeps no per-parameter state, so nothing is allocated here | ||
| 391 | − | end subroutine init_gradients_base | |
| 392 | !############################################################################### | ||
| 393 | |||
| 394 | |||
| 395 | !############################################################################### | ||
| 396 | − | pure subroutine minimise_base(this, param, gradient) | |
| 397 | !! Apply gradients to parameters to minimise loss using base optimiser | ||
| 398 | implicit none | ||
| 399 | |||
| 400 | ! Arguments | ||
| 401 | class(base_optimiser_type), intent(inout) :: this | ||
| 402 | !! Instance of the base optimiser | ||
| 403 | real(real32), dimension(:), intent(inout) :: param | ||
| 404 | !! Parameters | ||
| 405 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 406 | !! Gradients | ||
| 407 | |||
| 408 | ! Local variables | ||
| 409 | real(real32) :: learning_rate | ||
| 410 | !! Learning rate | ||
| 411 | |||
| 412 | |||
| 413 | ! Apply learning rate decay for the current iteration | ||
| 414 | − | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
| 415 | |||
| 416 | ! Update parameters | ||
| 417 | − | param = param - learning_rate * gradient | |
| 418 | − | end subroutine minimise_base | |
| 419 | !############################################################################### | ||
| 420 | |||
| 421 | |||
| 422 | !##############################################################################! | ||
| 423 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 424 | !##############################################################################! | ||
| 425 | |||
| 426 | |||
| 427 | !############################################################################### | ||
| 428 | − | subroutine print_to_unit_base(this, unit) | |
| 429 | !! Print base optimiser information | ||
| 430 | implicit none | ||
| 431 | |||
| 432 | ! Arguments | ||
| 433 | class(base_optimiser_type), intent(in) :: this | ||
| 434 | !! Instance of the base optimiser | ||
| 435 | integer, intent(in) :: unit | ||
| 436 | !! File unit | ||
| 437 | |||
| 438 | |||
| 439 | − | write(unit,'(6X,"NAME = ",A)') this%name | |
| 440 | − | write(unit,'(6X,"LEARNING_RATE = ",F10.5)') this%learning_rate | |
| 441 | − | write(unit,'(6X,"ITERATION = ",I10)') this%iter | |
| 442 | − | write(unit,'(6X,"EPOCH = ",I10)') this%epoch | |
| 443 | − | write(unit,'(6X,"REGULARISATION = ",L1)') this%regularisation | |
| 444 | |||
| 445 | − | end subroutine print_to_unit_base | |
| 446 | !############################################################################### | ||
| 447 | |||
| 448 | |||
| 449 | !############################################################################### | ||
| 450 | − | subroutine read_base(this, unit) | |
| 451 | !! Read base optimiser information | ||
| 452 | use athena__tools_infile, only: assign_val, assign_vec | ||
| 453 | use coreutils, only: to_lower, to_upper, icount | ||
| 454 | implicit none | ||
| 455 | |||
| 456 | ! Arguments | ||
| 457 | class(base_optimiser_type), intent(inout) :: this | ||
| 458 | !! Instance of the base optimiser | ||
| 459 | integer, intent(in) :: unit | ||
| 460 | !! File unit | ||
| 461 | |||
| 462 | ! Local variables | ||
| 463 | integer :: stat | ||
| 464 | !! File status | ||
| 465 | integer :: itmp1 | ||
| 466 | !! Temporary integer | ||
| 467 | character(256) :: buffer, tag, err_msg | ||
| 468 | !! Buffer for reading lines, tag for identifying lines, error message | ||
| 469 | |||
| 470 | |||
| 471 | ! Loop over tags in layer card | ||
| 472 | !--------------------------------------------------------------------------- | ||
| 473 | − | tag_loop: do | |
| 474 | |||
| 475 | ! Check for end of file | ||
| 476 | !------------------------------------------------------------------------ | ||
| 477 | − | read(unit,'(A)',iostat=stat) buffer | |
| 478 | − | if(stat.ne.0)then | |
| 479 | write(err_msg,'("file encountered error (EoF?) before END ",A)') & | ||
| 480 | − | to_upper(this%name) | |
| 481 | − | call stop_program(err_msg) | |
| 482 | − | return | |
| 483 | end if | ||
| 484 | − | if(trim(adjustl(buffer)).eq."") cycle tag_loop | |
| 485 | |||
| 486 | ! Check for end of layer card | ||
| 487 | !------------------------------------------------------------------------ | ||
| 488 | − | if(trim(adjustl(buffer)).eq."END OPTIMISER")then | |
| 489 | − | backspace(unit) | |
| 490 | − | exit tag_loop | |
| 491 | end if | ||
| 492 | |||
| 493 | − | tag=trim(adjustl(buffer)) | |
| 494 | − | if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1)) | |
| 495 | |||
| 496 | ! Read parameters from save file | ||
| 497 | !------------------------------------------------------------------------ | ||
| 498 | − | select case(trim(tag)) | |
| 499 | case("NAME") | ||
| 500 | − | call assign_val(buffer, this%name, itmp1) | |
| 501 | case("LEARNING_RATE") | ||
| 502 | − | call assign_val(buffer, this%learning_rate, itmp1) | |
| 503 | case("ITERATION") | ||
| 504 | − | call assign_val(buffer, this%iter, itmp1) | |
| 505 | case("EPOCH") | ||
| 506 | − | call assign_val(buffer, this%epoch, itmp1) | |
| 507 | case("REGULARISATION") | ||
| 508 | − | call assign_val(buffer, this%regularisation, itmp1) | |
| 509 | case default | ||
| 510 | ! The scan set deliberately omits "e" so that numbers written in scientific | ||
| 511 | ! notation (exponent, e.g. E+00) are not mistaken for unrecognised text | ||
| 512 | − | if(scan(to_lower(trim(adjustl(buffer))),& | |
| 513 | 'abcdfghijklmnopqrstuvwxyz').eq.0)then | ||
| 514 | − | cycle tag_loop | |
| 515 | − | elseif(tag(:3).eq.'END')then | |
| 516 | − | cycle tag_loop | |
| 517 | end if | ||
| 518 | write(err_msg,'("Unrecognised line in input file: ",A)') & | ||
| 519 | − | trim(adjustl(buffer)) | |
| 520 | − | call stop_program(err_msg) | |
| 521 | − | return | |
| 522 | end select | ||
| 523 | end do tag_loop | ||
| 524 | |||
| 525 | |||
| 526 | ! Check for end of layer card | ||
| 527 | !--------------------------------------------------------------------------- | ||
| 528 | − | read(unit,'(A)') buffer | |
| 529 | − | if(trim(adjustl(buffer)).ne."END OPTIMISER")then | |
| 530 | − | write(0,*) trim(adjustl(buffer)) | |
| 531 | − | write(err_msg,'("END OPTIMISER not where expected")') | |
| 532 | − | call stop_program(err_msg) | |
| 533 | − | return | |
| 534 | end if | ||
| 535 | |||
| 536 | end subroutine read_base | ||
| 537 | !################################################################################ | ||
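| | ! Example of the optimiser card written by print_to_unit_base and parsed by | ||
| | ! read_base (values are illustrative; any opening OPTIMISER line is assumed to | ||
| | ! be consumed by the caller, while the END OPTIMISER terminator is required): | ||
| | ! | ||
| | !   NAME = sgd | ||
| | !   LEARNING_RATE =    0.01000 | ||
| | !   ITERATION =        100 | ||
| | !   EPOCH =              2 | ||
| | !   REGULARISATION = F | ||
| | !   END OPTIMISER | ||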
| 538 | |||
| 539 | |||
| 540 | !##############################################################################! | ||
| 541 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 542 | !##############################################################################! | ||
| 543 | |||
| 544 | |||
| 545 | !############################################################################### | ||
| 546 | − | module function optimiser_setup_sgd( & | |
| 547 | learning_rate, momentum, & | ||
| 548 | nesterov, num_params, & | ||
| 549 | − | regulariser, clip_dict, lr_decay) result(optimiser) | |
| 550 | !! Set up the SGD optimiser | ||
| 551 | implicit none | ||
| 552 | |||
| 553 | ! Arguments | ||
| 554 | real(real32), optional, intent(in) :: learning_rate, momentum | ||
| 555 | !! Learning rate and momentum | ||
| 556 | logical, optional, intent(in) :: nesterov | ||
| 557 | !! Nesterov momentum | ||
| 558 | integer, optional, intent(in) :: num_params | ||
| 559 | !! Number of parameters | ||
| 560 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 561 | !! Regularisation method | ||
| 562 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 563 | !! Clipping dictionary | ||
| 564 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 565 | !! Learning rate decay method | ||
| 566 | |||
| 567 | type(sgd_optimiser_type) :: optimiser | ||
| 568 | !! Instance of the SGD optimiser | ||
| 569 | |||
| 570 | ! Local variables | ||
| 571 | integer :: num_params_ | ||
| 572 | !! Number of parameters | ||
| 573 | |||
| 574 | |||
| 575 | ! Initialise optimiser name | ||
| 576 | − | optimiser%name = "sgd" | |
| 577 | |||
| 578 | ! Apply regularisation | ||
| 579 | − | if(present(regulariser))then | |
| 580 | − | optimiser%regularisation = .true. | |
| 581 | − | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) | |
| 582 | − | allocate(optimiser%regulariser, source = regulariser) | |
| 583 | end if | ||
| 584 | |||
| 585 | ! Apply clipping | ||
| 586 | − | if(present(clip_dict)) optimiser%clip_dict = clip_dict | |
| 587 | |||
| 588 | ! Initialise general optimiser parameters | ||
| 589 | − | if(present(learning_rate)) optimiser%learning_rate = learning_rate | |
| 590 | − | if(present(momentum)) optimiser%momentum = momentum | |
| 591 | |||
| 592 | ! Initialise learning rate decay | ||
| 593 | − | if(present(lr_decay)) then | |
| 594 | − | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) | |
| 595 | − | allocate(optimiser%lr_decay, source = lr_decay) | |
| 596 | else | ||
| 597 | − | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
| 598 | end if | ||
| 599 | |||
| 600 | ! Initialise nesterov boolean | ||
| 601 | − | if(present(nesterov)) optimiser%nesterov = nesterov | |
| 602 | |||
| 603 | ! Initialise gradients | ||
| 604 | − | if(present(num_params)) then | |
| 605 | − | num_params_ = num_params | |
| 606 | else | ||
| 607 | − | num_params_ = 1 | |
| 608 | end if | ||
| 609 | − | call optimiser%init_gradients(num_params_) | |
| 610 | − | end function optimiser_setup_sgd | |
| 611 | !############################################################################### | ||
| 612 | |||
| 613 | |||
| 614 | !############################################################################### | ||
| 615 | − | pure subroutine init_gradients_sgd(this, num_params) | |
| 616 | !! Initialise gradients for SGD optimiser | ||
| 617 | implicit none | ||
| 618 | |||
| 619 | ! Arguments | ||
| 620 | class(sgd_optimiser_type), intent(inout) :: this | ||
| 621 | !! Instance of the SGD optimiser | ||
| 622 | integer, intent(in) :: num_params | ||
| 623 | !! Number of parameters | ||
| 624 | |||
| 625 | |||
| 626 | ! Initialise gradients | ||
| 627 | − | if(allocated(this%velocity)) deallocate(this%velocity) | |
| 628 | − | allocate(this%velocity(num_params), source=0._real32) | |
| 629 | − | end subroutine init_gradients_sgd | |
| 630 | !############################################################################### | ||
| 631 | |||
| 632 | |||
| 633 | !############################################################################### | ||
| 634 | − | pure subroutine minimise_sgd(this, param, gradient) | |
| 635 | !! Apply gradients to parameters to minimise loss using SGD optimiser | ||
| 636 | implicit none | ||
| 637 | |||
| 638 | ! Arguments | ||
| 639 | class(sgd_optimiser_type), intent(inout) :: this | ||
| 640 | !! Instance of the SGD optimiser | ||
| 641 | real(real32), dimension(:), intent(inout) :: param | ||
| 642 | !! Parameters | ||
| 643 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 644 | !! Gradients | ||
| 645 | |||
| 646 | ! Local variables | ||
| 647 | real(real32) :: learning_rate | ||
| 648 | !! Learning rate | ||
| 649 | |||
| 650 | |||
| 651 | ! Apply learning rate decay for the current iteration | ||
| 652 | − | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
| 653 | |||
| 654 | ! Apply regularisation | ||
| 655 | − | if(this%regularisation) & | |
| 656 | − | call this%regulariser%regularise( param, gradient, learning_rate ) | |
| 657 | |||
| 658 | − | gradient = - learning_rate * gradient | |
| 659 | ! Update parameters | ||
| 660 | − | if(this%momentum.gt.1.E-8_real32)then | |
| 661 | ! Momentum-based update | ||
| 662 | − | this%velocity = this%momentum * this%velocity + gradient | |
| 663 | − | if(this%nesterov)then | |
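| | ! Nesterov look-ahead: step with the freshly updated velocity scaled by the | ||
| | ! momentum plus the current (already learning-rate-scaled) gradient term | ||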
| 664 | − | param = param + this%momentum * this%velocity + gradient | |
| 665 | else | ||
| 666 | − | param = param + this%velocity | |
| 667 | end if | ||
| 668 | else | ||
| 669 | ! Plain gradient step (no momentum) | ||
| 670 | − | this%velocity = gradient | |
| 671 | − | param = param + this%velocity | |
| 672 | end if | ||
| 673 | − | end subroutine minimise_sgd | |
| 674 | !############################################################################### | ||
| 675 | |||
| 676 | |||
| 677 | !##############################################################################! | ||
| 678 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 679 | !##############################################################################! | ||
| 680 | |||
| 681 | |||
| 682 | !############################################################################### | ||
| 683 | − | module function optimiser_setup_rmsprop( & | |
| 684 | learning_rate, beta, epsilon, & | ||
| 685 | num_params, regulariser, clip_dict, lr_decay & | ||
| 686 | − | ) result(optimiser) | |
| 687 | !! Set up the RMSprop optimiser | ||
| 688 | implicit none | ||
| 689 | |||
| 690 | ! Arguments | ||
| 691 | real(real32), optional, intent(in) :: learning_rate | ||
| 692 | !! Learning rate | ||
| 693 | real(real32), optional, intent(in) :: beta, epsilon | ||
| 694 | !! Beta and epsilon | ||
| 695 | integer, optional, intent(in) :: num_params | ||
| 696 | !! Number of parameters | ||
| 697 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 698 | !! Regularisation method | ||
| 699 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 700 | !! Clipping dictionary | ||
| 701 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 702 | !! Learning rate decay method | ||
| 703 | |||
| 704 | type(rmsprop_optimiser_type) :: optimiser | ||
| 705 | !! Instance of the RMSprop optimiser | ||
| 706 | |||
| 707 | ! Local variables | ||
| 708 | integer :: num_params_ | ||
| 709 | !! Number of parameters | ||
| 710 | |||
| 711 | |||
| 712 | ! Initialise optimiser name | ||
| 713 | − | optimiser%name = "rmsprop" | |
| 714 | |||
| 715 | ! Apply regularisation | ||
| 716 | − | if(present(regulariser))then | |
| 717 | − | optimiser%regularisation = .true. | |
| 718 | − | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) | |
| 719 | − | allocate(optimiser%regulariser, source = regulariser) | |
| 720 | end if | ||
| 721 | |||
| 722 | ! Apply clipping | ||
| 723 | − | if(present(clip_dict)) optimiser%clip_dict = clip_dict | |
| 724 | |||
| 725 | ! Initialise general optimiser parameters | ||
| 726 | − | if(present(learning_rate)) optimiser%learning_rate = learning_rate | |
| 727 | |||
| 728 | ! Initialise learning rate decay | ||
| 729 | − | if(present(lr_decay)) then | |
| 730 | − | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) | |
| 731 | − | allocate(optimiser%lr_decay, source = lr_decay) | |
| 732 | else | ||
| 733 | − | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
| 734 | end if | ||
| 735 | |||
| 736 | ! Initialise RMSprop parameters | ||
| 737 | − | if(present(beta)) optimiser%beta = beta | |
| 738 | − | if(present(epsilon)) optimiser%epsilon = epsilon | |
| 739 | |||
| 740 | ! Initialise gradients | ||
| 741 | − | if(present(num_params)) then | |
| 742 | − | num_params_ = num_params | |
| 743 | else | ||
| 744 | − | num_params_ = 1 | |
| 745 | end if | ||
| 746 | − | call optimiser%init_gradients(num_params_) | |
| 747 | − | end function optimiser_setup_rmsprop | |
| 748 | !############################################################################### | ||
| 749 | |||
| 750 | |||
| 751 | !############################################################################### | ||
| 752 | − | pure subroutine init_gradients_rmsprop(this, num_params) | |
| 753 | !! Initialise gradients for RMSprop optimiser | ||
| 754 | implicit none | ||
| 755 | |||
| 756 | ! Arguments | ||
| 757 | class(rmsprop_optimiser_type), intent(inout) :: this | ||
| 758 | !! Instance of the RMSprop optimiser | ||
| 759 | integer, intent(in) :: num_params | ||
| 760 | !! Number of parameters | ||
| 761 | |||
| 762 | |||
| 763 | ! Initialise gradients | ||
| 764 | − | if(allocated(this%moving_avg)) deallocate(this%moving_avg) | |
| 765 | − | allocate(this%moving_avg(num_params), source=0._real32) | |
| 766 | − | end subroutine init_gradients_rmsprop | |
| 767 | !############################################################################### | ||
| 768 | |||
| 769 | |||
| 770 | !############################################################################### | ||
| 771 | − | pure subroutine minimise_rmsprop(this, param, gradient) | |
| 772 | !! Apply gradients to parameters to minimise loss using RMSprop optimiser | ||
| 773 | implicit none | ||
| 774 | |||
| 775 | ! Arguments | ||
| 776 | class(rmsprop_optimiser_type), intent(inout) :: this | ||
| 777 | !! Instance of the RMSprop optimiser | ||
| 778 | real(real32), dimension(:), intent(inout) :: param | ||
| 779 | !! Parameters | ||
| 780 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 781 | !! Gradients | ||
| 782 | |||
| 783 | ! Local variables | ||
| 784 | real(real32) :: learning_rate | ||
| 785 | !! Learning rate | ||
| 786 | |||
| 787 | |||
| 788 | ! Apply learning rate decay for the current iteration | ||
| 789 | − | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
| 790 | |||
| 791 | ! Apply regularisation | ||
| 792 | − | if(this%regularisation) & | |
| 793 | − | call this%regulariser%regularise( param, gradient, learning_rate ) | |
| 794 | |||
| 795 | − | this%moving_avg = this%beta * this%moving_avg + & | |
| 796 | − | (1._real32 - this%beta) * gradient ** 2._real32 | |
| 797 | |||
| 798 | ! Update parameters | ||
| 799 | − | param = param - learning_rate * gradient / & | |
| 800 | − | (sqrt(this%moving_avg + this%epsilon)) | |
| 801 | − | end subroutine minimise_rmsprop | |
| 802 | !############################################################################### | ||
| 803 | |||
| 804 | |||
| 805 | !##############################################################################! | ||
| 806 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 807 | !##############################################################################! | ||
| 808 | |||
| 809 | |||
| 810 | !############################################################################### | ||
| 811 | − | module function optimiser_setup_adagrad( & | |
| 812 | learning_rate, epsilon, & | ||
| 813 | num_params, regulariser, clip_dict, lr_decay & | ||
| 814 | − | ) result(optimiser) | |
| 815 | !! Set up the Adagrad optimiser | ||
| 816 | implicit none | ||
| 817 | |||
| 818 | ! Arguments | ||
| 819 | real(real32), optional, intent(in) :: learning_rate | ||
| 820 | !! Learning rate | ||
| 821 | real(real32), optional, intent(in) :: epsilon | ||
| 822 | !! Epsilon | ||
| 823 | integer, optional, intent(in) :: num_params | ||
| 824 | !! Number of parameters | ||
| 825 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 826 | !! Regularisation method | ||
| 827 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 828 | !! Clipping dictionary | ||
| 829 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 830 | !! Learning rate decay method | ||
| 831 | |||
| 832 | type(adagrad_optimiser_type) :: optimiser | ||
| 833 | !! Instance of the Adagrad optimiser | ||
| 834 | |||
| 835 | ! Local variables | ||
| 836 | integer :: num_params_ | ||
| 837 | !! Number of parameters | ||
| 838 | |||
| 839 | |||
| 840 | ! Initialise optimiser name | ||
| 841 | − | optimiser%name = "adagrad" | |
| 842 | |||
| 843 | ! Apply regularisation | ||
| 844 | − | if(present(regulariser))then | |
| 845 | − | optimiser%regularisation = .true. | |
| 846 | − | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) | |
| 847 | − | allocate(optimiser%regulariser, source = regulariser) | |
| 848 | end if | ||
| 849 | |||
| 850 | ! Apply clipping | ||
| 851 | − | if(present(clip_dict)) optimiser%clip_dict = clip_dict | |
| 852 | |||
| 853 | ! Initialise general optimiser parameters | ||
| 854 | − | if(present(learning_rate)) optimiser%learning_rate = learning_rate | |
| 855 | |||
| 856 | ! Initialise learning rate decay | ||
| 857 | − | if(present(lr_decay)) then | |
| 858 | − | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) | |
| 859 | − | allocate(optimiser%lr_decay, source = lr_decay) | |
| 860 | else | ||
| 861 | − | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
| 862 | end if | ||
| 863 | |||
| 864 | ! Initialise Adagrad parameters | ||
| 865 | − | if(present(epsilon)) optimiser%epsilon = epsilon | |
| 866 | |||
| 867 | ! Initialise gradients | ||
| 868 | − | if(present(num_params)) then | |
| 869 | − | num_params_ = num_params | |
| 870 | else | ||
| 871 | − | num_params_ = 1 | |
| 872 | end if | ||
| 873 | − | call optimiser%init_gradients(num_params_) | |
| 874 | − | end function optimiser_setup_adagrad | |
| 875 | !############################################################################### | ||
| 876 | |||
| 877 | |||
| 878 | !############################################################################### | ||
| 879 | − | pure subroutine init_gradients_adagrad(this, num_params) | |
| 880 | !! Initialise gradients for Adagrad optimiser | ||
| 881 | implicit none | ||
| 882 | |||
| 883 | ! Arguments | ||
| 884 | class(adagrad_optimiser_type), intent(inout) :: this | ||
| 885 | !! Instance of the Adagrad optimiser | ||
| 886 | integer, intent(in) :: num_params | ||
| 887 | !! Number of parameters | ||
| 888 | |||
| 889 | |||
| 890 | ! Initialise gradients | ||
| 891 | − | if(allocated(this%sum_squares)) deallocate(this%sum_squares) | |
| 892 | − | allocate(this%sum_squares(num_params), source=0._real32) | |
| 893 | − | end subroutine init_gradients_adagrad | |
| 894 | !############################################################################### | ||
| 895 | |||
| 896 | |||
| 897 | !############################################################################### | ||
| 898 | − | pure subroutine minimise_adagrad(this, param, gradient) | |
| 899 | !! Apply gradients to parameters to minimise loss using Adagrad optimiser | ||
| 900 | implicit none | ||
| 901 | |||
| 902 | ! Arguments | ||
| 903 | class(adagrad_optimiser_type), intent(inout) :: this | ||
| 904 | !! Instance of the Adagrad optimiser | ||
| 905 | real(real32), dimension(:), intent(inout) :: param | ||
| 906 | !! Parameters | ||
| 907 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 908 | !! Gradients | ||
| 909 | |||
| 910 | real(real32) :: learning_rate | ||
| 911 | !! Learning rate | ||
| 912 | |||
| 913 | |||
| 914 | ! Apply learning rate decay for the current iteration | ||
| 915 | − | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
| 916 | |||
| 917 | ! Apply regularisation | ||
| 918 | − | if(this%regularisation) & | |
| 919 | − | call this%regulariser%regularise( param, gradient, learning_rate ) | |
| 920 | |||
| 921 | − | this%sum_squares = this%sum_squares + gradient ** 2._real32 | |
| 922 | |||
| 923 | ! Update parameters | ||
| 924 | − | param = param - learning_rate * gradient / & | |
| 925 | − | (sqrt(this%sum_squares + this%epsilon)) | |
| 926 | − | end subroutine minimise_adagrad | |
| 927 | !############################################################################### | ||
| 928 | |||
| 929 | |||
| 930 | !##############################################################################! | ||
| 931 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 932 | !##############################################################################! | ||
| 933 | |||
| 934 | |||
| 935 | !############################################################################### | ||
| 936 | − | module function optimiser_setup_adam( & | |
| 937 | learning_rate, beta1, beta2, epsilon, & | ||
| 938 | num_params, regulariser, clip_dict, lr_decay & | ||
| 939 | − | ) result(optimiser) | |
| 940 | !! Set up the Adam optimiser | ||
| 941 | implicit none | ||
| 942 | |||
| 943 | ! Arguments | ||
| 944 | real(real32), optional, intent(in) :: learning_rate | ||
| 945 | !! Learning rate | ||
| 946 | real(real32), optional, intent(in) :: beta1, beta2, epsilon | ||
| 947 | !! Beta1, beta2, and epsilon | ||
| 948 | integer, optional, intent(in) :: num_params | ||
| 949 | !! Number of parameters | ||
| 950 | class(base_regulariser_type), optional, intent(in) :: regulariser | ||
| 951 | !! Regularisation method | ||
| 952 | type(clip_type), optional, intent(in) :: clip_dict | ||
| 953 | !! Clipping dictionary | ||
| 954 | class(base_lr_decay_type), optional, intent(in) :: lr_decay | ||
| 955 | !! Learning rate decay method | ||
| 956 | |||
| 957 | type(adam_optimiser_type) :: optimiser | ||
| 958 | !! Instance of the Adam optimiser | ||
| 959 | |||
| 960 | ! Local variables | ||
| 961 | integer :: num_params_ | ||
| 962 | !! Number of parameters | ||
| 963 | |||
| 964 | |||
| 965 | ! Initialise optimiser name | ||
| 966 | − | optimiser%name = "adam" | |
| 967 | |||
| 968 | ! Apply regularisation | ||
| 969 | − | if(present(regulariser))then | |
| 970 | − | optimiser%regularisation = .true. | |
| 971 | − | if(allocated(optimiser%regulariser)) deallocate(optimiser%regulariser) | |
| 972 | − | allocate(optimiser%regulariser, source = regulariser) | |
| 973 | end if | ||
| 974 | |||
| 975 | ! Apply clipping | ||
| 976 | − | if(present(clip_dict)) optimiser%clip_dict = clip_dict | |
| 977 | |||
| 978 | ! Initialise general optimiser parameters | ||
| 979 | − | if(present(learning_rate)) optimiser%learning_rate = learning_rate | |
| 980 | |||
| 981 | ! Initialise learning rate decay | ||
| 982 | − | if(present(lr_decay)) then | |
| 983 | − | if(allocated(optimiser%lr_decay)) deallocate(optimiser%lr_decay) | |
| 984 | − | allocate(optimiser%lr_decay, source = lr_decay) | |
| 985 | else | ||
| 986 | − | allocate(optimiser%lr_decay, source = base_lr_decay_type()) | |
| 987 | end if | ||
| 988 | |||
| 989 | ! Initialise Adam parameters | ||
| 990 | − | if(present(beta1)) optimiser%beta1 = beta1 | |
| 991 | − | if(present(beta2)) optimiser%beta2 = beta2 | |
| 992 | − | if(present(epsilon)) optimiser%epsilon = epsilon | |
| 993 | |||
| 994 | ! Initialise gradients | ||
| 995 | − | if(present(num_params)) then | |
| 996 | − | num_params_ = num_params | |
| 997 | else | ||
| 998 | − | num_params_ = 1 | |
| 999 | end if | ||
| 1000 | − | call optimiser%init_gradients(num_params_) | |
| 1001 | − | end function optimiser_setup_adam | |
| 1002 | !############################################################################### | ||
| 1003 | |||
| 1004 | |||
| 1005 | !############################################################################### | ||
| 1006 | − | pure subroutine init_gradients_adam(this, num_params) | |
| 1007 | !! Initialise gradients for Adam optimiser | ||
| 1008 | implicit none | ||
| 1009 | |||
| 1010 | ! Arguments | ||
| 1011 | class(adam_optimiser_type), intent(inout) :: this | ||
| 1012 | !! Instance of the Adam optimiser | ||
| 1013 | integer, intent(in) :: num_params | ||
| 1014 | !! Number of parameters | ||
| 1015 | |||
| 1016 | |||
| 1017 | ! Initialise gradients | ||
| 1018 | − | if(allocated(this%m)) deallocate(this%m) | |
| 1019 | − | if(allocated(this%v)) deallocate(this%v) | |
| 1020 | − | allocate(this%m(num_params), source=0._real32) | |
| 1021 | − | allocate(this%v(num_params), source=0._real32) | |
| 1022 | − | end subroutine init_gradients_adam | |
| 1023 | !############################################################################### | ||
| 1024 | |||
| 1025 | |||
| 1026 | !############################################################################### | ||
| 1027 | − | pure subroutine minimise_adam(this, param, gradient) | |
| 1028 | !! Apply gradients to parameters to minimise loss using Adam optimiser | ||
| 1029 | implicit none | ||
| 1030 | |||
| 1031 | ! Arguments | ||
| 1032 | class(adam_optimiser_type), intent(inout) :: this | ||
| 1033 | !! Instance of the Adam optimiser | ||
| 1034 | real(real32), dimension(:), intent(inout) :: param | ||
| 1035 | !! Parameters | ||
| 1036 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 1037 | !! Gradients | ||
| 1038 | |||
| 1039 | ! Local variables | ||
| 1040 | real(real32) :: learning_rate | ||
| 1041 | !! Learning rate | ||
| 1042 | |||
| 1043 | |||
| 1044 | ! Apply learning rate decay for the current iteration | ||
| 1045 | − | learning_rate = this%lr_decay%get_lr(this%learning_rate, this%iter) | |
| 1046 | |||
| 1047 | ! Apply regularisation | ||
| 1048 | − | if(this%regularisation) & | |
| 1049 | − | call this%regulariser%regularise( param, gradient, learning_rate ) | |
| 1050 | |||
| 1051 | ! Adaptive learning method | ||
| 1052 | − | this%m = this%beta1 * this%m + & | |
| 1053 | − | (1._real32 - this%beta1) * gradient | |
| 1054 | − | this%v = this%beta2 * this%v + & | |
| 1055 | − | (1._real32 - this%beta2) * gradient ** 2._real32 | |
| 1056 | |||
| 1057 | ! Update parameters | ||
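| | ! (the bias corrections below divide by 1 - beta**iter, so the iteration | ||
| | !  counter is assumed to have been advanced to at least 1 before this call) | ||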
| 1058 | associate( & | ||
| 1059 | − | m_hat => this%m / (1._real32 - this%beta1**this%iter), & | |
| 1060 | − | v_hat => this%v / (1._real32 - this%beta2**this%iter) ) | |
| 1061 | − | select type(regulariser => this%regulariser) | |
| 1062 | type is (l2_regulariser_type) | ||
| 1063 | − | select case(regulariser%decoupled) | |
| 1064 | case(.true.) | ||
| 1065 | ! decoupled weight decay (AdamW) | ||
| 1066 | − | param = param - learning_rate * & | |
| 1067 | ( & | ||
| 1068 | − | m_hat / (sqrt(v_hat) + this%epsilon) + & | |
| 1069 | − | regulariser%l2 * param & | |
| 1070 | − | ) | |
| 1071 | case(.false.) | ||
| 1072 | ! classical L2 regularisation (included in gradient) | ||
| 1073 | − | param = param - learning_rate * ( & | |
| 1074 | − | ( m_hat + regulariser%l2 * param ) / & | |
| 1075 | − | ( sqrt(v_hat) + this%epsilon ) & | |
| 1076 | − | ) | |
| 1077 | end select | ||
| 1078 | class default | ||
| 1079 | ! no L2-specific handling here; standard Adam update | ||
| 1080 | − | param = param - learning_rate * ( & | |
| 1081 | − | m_hat / (sqrt(v_hat) + this%epsilon) ) | |
| 1082 | end select | ||
| 1083 | end associate | ||
| 1084 | − | end subroutine minimise_adam | |
| 1085 | !############################################################################### | ||
| 1086 | |||
| 1087 | − | end module athena__optimiser | |
| 1088 |