!!!#############################################################################
!!! Code written by Ned Thaddeus Taylor
!!! Code part of the ATHENA library - a feedforward neural network library
!!!#############################################################################
!!! module contains regularisation methods and associated derived types
!!! module contains the following derived types:
!!! - base_regulariser_type - abstract base regulariser type
!!! - l1_regulariser_type   - L1 regulariser type
!!! - l2_regulariser_type   - L2 regulariser type
!!! - l1l2_regulariser_type - L1L2 regulariser type
!!!##################
!!! the base_regulariser_type contains the following deferred procedure:
!!! - regularise - regularise the gradient
!!!##################
!!! every regularise implementation ADDS its penalty term, scaled by the
!!! supplied learning rate, onto the gradient array it is handed; the
!!! parameters themselves are never modified
!!!#############################################################################
module regulariser
  use constants, only: real12
  implicit none


  private

  public :: base_regulariser_type
  public :: l1_regulariser_type
  public :: l2_regulariser_type
  public :: l1l2_regulariser_type


!!!-----------------------------------------------------------------------------
!!! regularise type
!!!-----------------------------------------------------------------------------
  !! abstract parent of all regularisers
  !! holds no data; only fixes the signature of the regularise binding
  type, abstract :: base_regulariser_type
   contains
     procedure(regularise), deferred, pass(this) :: regularise
  end type base_regulariser_type

  abstract interface
     !! signature shared by every regulariser
     !! this          = regulariser supplying the penalty coefficients
     !! params        = parameters the penalty is evaluated on (read only)
     !! gradient      = gradient array, updated in place
     !! learning_rate = factor the penalty term is multiplied by
     !! params and gradient are combined element by element, so they
     !! must have the same size (not checked)
     pure subroutine regularise(this, params, gradient, learning_rate)
       import :: base_regulariser_type, real12
       class(base_regulariser_type), intent(in) :: this
       real(real12), dimension(:), intent(in) :: params
       real(real12), dimension(:), intent(inout) :: gradient
       real(real12), intent(in) :: learning_rate
     end subroutine regularise
  end interface

  !! Lasso regression
  !! attempts to prevent overfitting
  type, extends(base_regulariser_type) :: l1_regulariser_type
     !! l1 = L1 regularisation coefficient
     real(real12) :: l1 = 0.01_real12
   contains
     procedure, pass(this) :: regularise => regularise_l1
  end type l1_regulariser_type

  !! Ridge regression
  !! attempts to prevent overfitting
  type, extends(base_regulariser_type) :: l2_regulariser_type
     !! l2 = L2 regularisation
     !! l2_decoupled = decoupled weight decay regularisation (AdamW)
     !! NOTE: l2_decoupled and decoupled are not read anywhere in this
     !!       module; presumably they are consumed by the optimiser that
     !!       owns the regulariser - confirm against the callers
     real(real12) :: l2 = 0.01_real12
     real(real12) :: l2_decoupled = 0.01_real12
     logical :: decoupled = .true.
   contains
     procedure, pass(this) :: regularise => regularise_l2
  end type l2_regulariser_type

  !! combined L1 and L2 penalty
  type, extends(base_regulariser_type) :: l1l2_regulariser_type
     !! l1 = L1 regularisation coefficient
     !! l2 = L2 regularisation coefficient
     real(real12) :: l1 = 0.01_real12
     real(real12) :: l2 = 0.01_real12
   contains
     procedure, pass(this) :: regularise => regularise_l1l2
  end type l1l2_regulariser_type


contains

!!!#############################################################################
!!! subgradient of |x|
!!!#############################################################################
!!! returns -1 for x < 0, +1 for x > 0 and 0 for x = 0 (either signed zero)
!!! the bare intrinsic sign(1,x) yields +/-1 at x = 0, which would keep
!!! pushing a parameter that already sits exactly on zero
  elemental function abs_subgradient(x) result(output)
    real(real12), intent(in) :: x
    real(real12) :: output

    !! exact comparison is intended: only a true zero has a zero subgradient
    if (x == 0._real12) then
       output = 0._real12
    else
       output = sign(1._real12, x)
    end if

  end function abs_subgradient
!!!#############################################################################


!!!#############################################################################
!!! regularise
!!!#############################################################################
!!! L1: gradient = gradient + learning_rate * l1 * d|params|/dparams
  pure subroutine regularise_l1(this, params, gradient, learning_rate)
    class(l1_regulariser_type), intent(in) :: this
    real(real12), dimension(:), intent(in) :: params
    real(real12), dimension(:), intent(inout) :: gradient
    real(real12), intent(in) :: learning_rate

    gradient = gradient + learning_rate * this%l1 * abs_subgradient(params)

  end subroutine regularise_l1
!!!-----------------------------------------------------------------------------
!!! L2: gradient = gradient + learning_rate * 2 * l2 * params
!!! only the l2 component is used here
  pure subroutine regularise_l2(this, params, gradient, learning_rate)
    class(l2_regulariser_type), intent(in) :: this
    real(real12), dimension(:), intent(in) :: params
    real(real12), dimension(:), intent(inout) :: gradient
    real(real12), intent(in) :: learning_rate

    gradient = gradient + learning_rate * 2._real12 * this%l2 * params

  end subroutine regularise_l2
!!!-----------------------------------------------------------------------------
!!! L1L2: gradient = gradient + learning_rate * &
!!!                  (l1 * d|params|/dparams + 2 * l2 * params)
  pure subroutine regularise_l1l2(this, params, gradient, learning_rate)
    class(l1l2_regulariser_type), intent(in) :: this
    real(real12), dimension(:), intent(in) :: params
    real(real12), dimension(:), intent(inout) :: gradient
    real(real12), intent(in) :: learning_rate

    gradient = gradient + learning_rate * &
         (this%l1 * abs_subgradient(params) + 2._real12 * this%l2 * params)

  end subroutine regularise_l1l2
!!!#############################################################################

end module regulariser