| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | module athena__regulariser |
| 2 | | | !! Module containing regularisation methods |
| 3 | | | !! |
| 4 | | | !! This module contains regularisation methods to prevent overfitting |
| 5 | | | !! in neural networks |
| 6 | | | use coreutils, only: real32 |
| 7 | | | implicit none |
| 8 | | | |
| 9 | | | |
| 10 | | | private |
| 11 | | | |
| 12 | | | public :: base_regulariser_type |
| 13 | | | public :: l1_regulariser_type |
| 14 | | | public :: l2_regulariser_type |
| 15 | | | public :: l1l2_regulariser_type |
| 16 | | | |
| 17 | | | |
| 18 | | | type, abstract :: base_regulariser_type |
| 19 | | | !! Abstract type for regularisation |
| 20 | | | contains |
| 21 | | | procedure(regularise), deferred, pass(this) :: regularise |
| 22 | | | !! Regularisation method |
| 23 | | | end type base_regulariser_type |
| 24 | | | |
| 25 | | | abstract interface |
| 26 | | | pure subroutine regularise(this, params, gradient, learning_rate) |
| 27 | | | !! Regularise the parameters |
| 28 | | | import :: base_regulariser_type, real32 |
| 29 | | | class(base_regulariser_type), intent(in) :: this |
| 30 | | | !! Regulariser object |
| 31 | | | real(real32), dimension(:), intent(in) :: params |
| 32 | | | !! Parameters to regularise |
| 33 | | | real(real32), dimension(:), intent(inout) :: gradient |
| 34 | | | !! Gradient of the parameters |
| 35 | | | real(real32), intent(in) :: learning_rate |
| 36 | | | !! Learning rate |
| 37 | | | end subroutine regularise |
| 38 | | | end interface |
| 39 | | | |
| 40 | | | type, extends(base_regulariser_type) :: l1_regulariser_type |
| 41 | | | !! Type for L1 regularisation |
| 42 | | | !! |
| 43 | | | !! L1 regularisation penalises the absolute values of the parameters, as in Lasso regression |
| 44 | | | !! It is used to prevent overfitting in neural networks by encouraging sparse weights |
| 45 | | | real(real32) :: l1 = 0.01_real32 |
| 46 | | | contains |
| 47 | | | procedure, pass(this) :: regularise => regularise_l1 |
| 48 | | | !! Regularisation method |
| 49 | | | end type l1_regulariser_type |
| 50 | | | |
| 51 | | | type, extends(base_regulariser_type) :: l2_regulariser_type |
| 52 | | | !! Type for L2 regularisation |
| 53 | | | !! |
| 54 | | | !! L2 regularisation penalises the squared magnitude of the parameters, as in Ridge regression |
| 55 | | | !! It is used to prevent overfitting in neural networks |
| 56 | | | !! l2 = standard (coupled) L2 regularisation |
| 57 | | | !! l2_decoupled = decoupled weight decay regularisation (AdamW) |
| 58 | | | real(real32) :: l2 = 0.01_real32 |
| 59 | | | !! L2 regularisation parameter |
| 60 | | | real(real32) :: l2_decoupled = 0.01_real32 |
| 61 | | | !! Decoupled weight decay regularisation parameter |
| 62 | | | logical :: decoupled = .true. |
| 63 | | | !! Whether to use decoupled weight decay regularisation |
| 64 | | | contains |
| 65 | | | procedure, pass(this) :: regularise => regularise_l2 |
| 66 | | | !! Regularisation method |
| 67 | | | end type l2_regulariser_type |
| 68 | | | |
| 69 | | | type, extends(base_regulariser_type) :: l1l2_regulariser_type |
| 70 | | | !! Type for combined L1 and L2 (elastic net) regularisation |
| 71 | | | real(real32) :: l1 = 0.01_real32 |
| 72 | | | !! L1 regularisation parameter |
| 73 | | | real(real32) :: l2 = 0.01_real32 |
| 74 | | | !! L2 regularisation parameter |
| 75 | | | contains |
| 76 | | | procedure, pass(this) :: regularise => regularise_l1l2 |
| 77 | | | !! Regularisation method |
| 78 | | | end type l1l2_regulariser_type |
| 79 | | | |
| 80 | |||
| 81 | |||
| 82 | contains | ||
| 83 | |||
| 84 | !############################################################################### | ||
| 85 | − | pure subroutine regularise_l1(this, params, gradient, learning_rate) | |
| 86 | !! Regularise the parameters using L1 regularisation | ||
| 87 | implicit none | ||
| 88 | |||
| 89 | ! Arguments | ||
| 90 | class(l1_regulariser_type), intent(in) :: this | ||
| 91 | !! Instance of the L1 regulariser | ||
| 92 | real(real32), dimension(:), intent(in) :: params | ||
| 93 | !! Parameters to regularise | ||
| 94 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 95 | !! Gradient of the parameters | ||
| 96 | real(real32), intent(in) :: learning_rate | ||
| 97 | !! Learning rate | ||
| 98 | |||
| 99 | − | gradient = gradient + learning_rate * this%l1 * sign(1._real32,params) | |
| 100 | |||
| 101 | − | end subroutine regularise_l1 | |
| 102 | !------------------------------------------------------------------------------- | ||
| 103 | − | pure subroutine regularise_l2(this, params, gradient, learning_rate) | |
| 104 | !! Regularise the parameters using L2 regularisation | ||
| 105 | implicit none | ||
| 106 | |||
| 107 | ! Arguments | ||
| 108 | class(l2_regulariser_type), intent(in) :: this | ||
| 109 | !! Instance of the L2 regulariser | ||
| 110 | real(real32), dimension(:), intent(in) :: params | ||
| 111 | !! Parameters to regularise | ||
| 112 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 113 | !! Gradient of the parameters | ||
| 114 | real(real32), intent(in) :: learning_rate | ||
| 115 | !! Learning rate | ||
| 116 | |||
| 117 | − | gradient = gradient + learning_rate * 2._real32 * this%l2 * params | |
| 118 | |||
| 119 | − | end subroutine regularise_l2 | |
| 120 | !------------------------------------------------------------------------------- | ||
| 121 | − | pure subroutine regularise_l1l2(this, params, gradient, learning_rate) | |
| 122 | !! Regularise the parameters using L1 and L2 regularisation | ||
| 123 | implicit none | ||
| 124 | |||
| 125 | ! Arguments | ||
| 126 | class(l1l2_regulariser_type), intent(in) :: this | ||
| 127 | !! Instance of the L1 and L2 regulariser | ||
| 128 | real(real32), dimension(:), intent(in) :: params | ||
| 129 | !! Parameters to regularise | ||
| 130 | real(real32), dimension(:), intent(inout) :: gradient | ||
| 131 | !! Gradient of the parameters | ||
| 132 | real(real32), intent(in) :: learning_rate | ||
| 133 | !! Learning rate | ||
| 134 | |||
| 135 | − | gradient = gradient + learning_rate * & | |
| 136 | − | (this%l1 * sign(1._real32,params) + 2._real32 * this%l2 * params) | |
| 137 | |||
| 138 | − | end subroutine regularise_l1l2 | |
| 139 | !############################################################################### | ||
| 140 | |||
| 141 | − | end module athena__regulariser | |
| 142 |
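
For reference, the three `regularise` implementations listed above apply the following in-place gradient updates; the notation is introduced here for clarity ($g$ = `gradient`, $w$ = `params`, $\eta$ = `learning_rate`, $\lambda_1$ = `l1`, $\lambda_2$ = `l2`):

```latex
% In-place gradient updates performed by the listed regularisers
% g = gradient, w = params, eta = learning_rate, lambda_1 = l1, lambda_2 = l2
\begin{aligned}
  \text{L1 (regularise\_l1):}     &\quad g \leftarrow g + \eta\,\lambda_1\,\operatorname{sign}(w) \\
  \text{L2 (regularise\_l2):}     &\quad g \leftarrow g + 2\,\eta\,\lambda_2\,w \\
  \text{L1L2 (regularise\_l1l2):} &\quad g \leftarrow g + \eta\,\bigl(\lambda_1\,\operatorname{sign}(w) + 2\,\lambda_2\,w\bigr)
\end{aligned}
```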
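
A minimal usage sketch follows, assuming the module compiles as listed. The driver program, the chosen penalty value, and the `iso_fortran_env` import are illustrative additions; the listing itself takes `real32` from `coreutils`, which is assumed here to name the standard 32-bit real kind.

```fortran
! Illustrative driver (not part of the module under coverage).
program regulariser_demo
  use, intrinsic :: iso_fortran_env, only: real32   ! assumed to match coreutils' real32
  use athena__regulariser, only: base_regulariser_type, l2_regulariser_type
  implicit none

  class(base_regulariser_type), allocatable :: reg
  real(real32) :: params(4), gradient(4)
  real(real32), parameter :: learning_rate = 0.01_real32

  params   = [0.5_real32, -1.0_real32, 2.0_real32, 0.0_real32]
  gradient = 0._real32

  ! Build an L2 regulariser with a non-default penalty strength
  allocate(reg, source=l2_regulariser_type(l2=0.001_real32))

  ! Adds learning_rate * 2 * l2 * params to the gradient in place
  call reg%regularise(params, gradient, learning_rate)

  print *, gradient   ! expect roughly [1.0e-5, -2.0e-5, 4.0e-5, 0.0]
end program regulariser_demo
```

Note that `regularise_l2` adds only the coupled penalty to the gradient; the `decoupled` and `l2_decoupled` components are not referenced inside the listed subroutine, so any decoupled (AdamW-style) weight decay is presumably applied by the calling optimiser rather than here.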
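
Because the deferred binding pins down the `regularise` interface, a user-defined penalty can be added by extending `base_regulariser_type`. The sketch below is hypothetical: the `scaled_l2_regulariser_type` name and its `scale` component are not part of the library, and the same `real32`-kind assumption as above applies.

```fortran
! Hypothetical extension of the abstract regulariser type (illustrative only).
module scaled_l2_regulariser
  use, intrinsic :: iso_fortran_env, only: real32   ! assumed to match coreutils' real32
  use athena__regulariser, only: base_regulariser_type
  implicit none
  private
  public :: scaled_l2_regulariser_type

  type, extends(base_regulariser_type) :: scaled_l2_regulariser_type
     !! L2 penalty with an extra user-controlled scale factor
     real(real32) :: l2 = 0.01_real32
     real(real32) :: scale = 0.5_real32
   contains
     procedure, pass(this) :: regularise => regularise_scaled_l2
  end type scaled_l2_regulariser_type

contains

  pure subroutine regularise_scaled_l2(this, params, gradient, learning_rate)
    !! Must match the deferred interface from the listing (same argument names and kinds)
    class(scaled_l2_regulariser_type), intent(in) :: this
    real(real32), dimension(:), intent(in) :: params
    real(real32), dimension(:), intent(inout) :: gradient
    real(real32), intent(in) :: learning_rate

    gradient = gradient + this%scale * learning_rate * 2._real32 * this%l2 * params
  end subroutine regularise_scaled_l2

end module scaled_l2_regulariser
```

An object of this type can then be passed anywhere a `class(base_regulariser_type)` argument is accepted.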