Line | Branch | Exec | Source |
---|---|---|---|
1 | !!!############################################################################# | ||
2 | !!! Code written by Ned Thaddeus Taylor | ||
3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
4 | !!!############################################################################# | ||
5 | !!! module contains regularisation methods and associated derived types | ||
6 | !!! module contains the following derived types: | ||
7 | !!! - base_regulariser_type - abstract base regulariser type | ||
8 | !!! - l1_regulariser_type - L1 regulariser type | ||
9 | !!! - l2_regulariser_type - L2 regulariser type | ||
10 | !!! - l1l2_regulariser_type - L1L2 regulariser type | ||
11 | !!!################## | ||
12 | !!! the base_regulariser_type contains the following deferred procedure: | ||
13 | !!! - regularise - regularise the gradient | ||
14 | !!!############################################################################# | ||
15 | module regulariser | ||
16 | use constants, only: real12 | ||
17 | implicit none | ||
18 | |||
19 | |||
20 | !!!----------------------------------------------------------------------------- | ||
21 | !!! regularise type | ||
22 | !!!----------------------------------------------------------------------------- | ||
23 | type, abstract :: base_regulariser_type | ||
24 | contains | ||
25 | procedure(regularise), deferred, pass(this) :: regularise | ||
26 | end type base_regulariser_type | ||
27 | |||
28 | abstract interface | ||
29 | pure subroutine regularise(this, params, gradient, learning_rate) | ||
30 | import :: base_regulariser_type, real12 | ||
31 | class(base_regulariser_type), intent(in) :: this | ||
32 | real(real12), dimension(:), intent(in) :: params | ||
33 | real(real12), dimension(:), intent(inout) :: gradient | ||
34 | real(real12), intent(in) :: learning_rate | ||
35 | end subroutine regularise | ||
36 | end interface | ||
37 | |||
38 | !! Lasso regression | ||
39 | !! attempts to prevent overfitting | ||
40 | type, extends(base_regulariser_type) :: l1_regulariser_type | ||
41 | real(real12) :: l1 = 0.01_real12 | ||
42 | contains | ||
43 | procedure, pass(this) :: regularise => regularise_l1 | ||
44 | end type l1_regulariser_type | ||
45 | |||
46 | !! Ridge regression | ||
47 | !! attempts to prevent overfitting | ||
48 | type, extends(base_regulariser_type) :: l2_regulariser_type | ||
49 | !! l2 = L2 regularisation | ||
50 | !! l2_decoupled = decoupled weight decay regularisation (AdamW) | ||
51 | real(real12) :: l2 = 0.01_real12 | ||
52 | real(real12) :: l2_decoupled = 0.01_real12 | ||
53 | logical :: decoupled = .true. | ||
54 | contains | ||
55 | procedure, pass(this) :: regularise => regularise_l2 | ||
56 | end type l2_regulariser_type | ||
57 | |||
58 | type, extends(base_regulariser_type) :: l1l2_regulariser_type | ||
59 | real(real12) :: l1 = 0.01_real12 | ||
60 | real(real12) :: l2 = 0.01_real12 | ||
61 | contains | ||
62 | procedure, pass(this) :: regularise => regularise_l1l2 | ||
63 | end type l1l2_regulariser_type | ||
64 | |||
65 | |||
66 | private | ||
67 | |||
68 | public :: base_regulariser_type | ||
69 | public :: l1_regulariser_type | ||
70 | public :: l2_regulariser_type | ||
71 | public :: l1l2_regulariser_type | ||
72 | |||
73 | |||
74 | contains | ||
75 | |||
76 | !!!############################################################################# | ||
77 | !!! regularise | ||
78 | !!!############################################################################# | ||
79 |
2/4✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
|
5 | pure subroutine regularise_l1(this, params, gradient, learning_rate) |
80 | class(l1_regulariser_type), intent(in) :: this | ||
81 | real(real12), dimension(:), intent(in) :: params | ||
82 | real(real12), dimension(:), intent(inout) :: gradient | ||
83 | real(real12), intent(in) :: learning_rate | ||
84 | |||
85 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✓ Branch 33 taken 41 times.
✓ Branch 34 taken 5 times.
|
46 | gradient = gradient + learning_rate * this%l1 * sign(1._real12,params) |
86 | |||
87 | 5 | end subroutine regularise_l1 | |
88 | !!!----------------------------------------------------------------------------- | ||
89 |
2/4✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
|
3 | pure subroutine regularise_l2(this, params, gradient, learning_rate) |
90 | class(l2_regulariser_type), intent(in) :: this | ||
91 | real(real12), dimension(:), intent(in) :: params | ||
92 | real(real12), dimension(:), intent(inout) :: gradient | ||
93 | real(real12), intent(in) :: learning_rate | ||
94 | |||
95 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✓ Branch 33 taken 21 times.
✓ Branch 34 taken 3 times.
|
24 | gradient = gradient + learning_rate * 2._real12 * this%l2 * params |
96 | |||
97 | 3 | end subroutine regularise_l2 | |
98 | !!!----------------------------------------------------------------------------- | ||
99 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | pure subroutine regularise_l1l2(this, params, gradient, learning_rate) |
100 | class(l1l2_regulariser_type), intent(in) :: this | ||
101 | real(real12), dimension(:), intent(in) :: params | ||
102 | real(real12), dimension(:), intent(inout) :: gradient | ||
103 | real(real12), intent(in) :: learning_rate | ||
104 | |||
105 | ✗ | gradient = gradient + learning_rate * & | |
106 |
17/32✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 38 taken 1 times.
✓ Branch 39 taken 1 times.
|
2 | (this%l1 * sign(1._real12,params) + 2._real12 * this%l2 * params) |
107 | |||
108 | 1 | end subroutine regularise_l1l2 | |
109 | !!!############################################################################# | ||
110 | |||
111 | 28 | end module regulariser | |
112 |