GCC Code Coverage Report


Directory: src/lib/
File: src/lib/mod_regulariser.f90
Date: 2024-06-28 12:51:18
Exec Total Coverage
Lines: 10 11 90.9%
Functions: 0 0 -%
Branches: 49 92 53.3%

Line Branch Exec Source
1 !!!#############################################################################
2 !!! Code written by Ned Thaddeus Taylor
3 !!! Code part of the ATHENA library - a feedforward neural network library
4 !!!#############################################################################
5 !!! module contains regularisation methods and associated derived types
6 !!! module contains the following derived types:
7 !!! - base_regulariser_type - abstract base regulariser type
8 !!! - l1_regulariser_type - L1 regulariser type
9 !!! - l2_regulariser_type - L2 regulariser type
10 !!! - l1l2_regulariser_type - L1L2 regulariser type
11 !!!##################
12 !!! the base_regulariser_type contains the following deferred procedure:
13 !!! - regularise - regularise the gradient
14 !!!#############################################################################
15 module regulariser
16 use constants, only: real12
17 implicit none
18
19
20 !!!-----------------------------------------------------------------------------
21 !!! regularise type
22 !!!-----------------------------------------------------------------------------
23 type, abstract :: base_regulariser_type
24 contains
25 procedure(regularise), deferred, pass(this) :: regularise
26 end type base_regulariser_type
27
28 abstract interface
29 pure subroutine regularise(this, params, gradient, learning_rate)
30 import :: base_regulariser_type, real12
31 class(base_regulariser_type), intent(in) :: this
32 real(real12), dimension(:), intent(in) :: params
33 real(real12), dimension(:), intent(inout) :: gradient
34 real(real12), intent(in) :: learning_rate
35 end subroutine regularise
36 end interface
37
38 !! Lasso regression
39 !! attempts to prevent overfitting
40 type, extends(base_regulariser_type) :: l1_regulariser_type
41 real(real12) :: l1 = 0.01_real12
42 contains
43 procedure, pass(this) :: regularise => regularise_l1
44 end type l1_regulariser_type
45
46 !! Ridge regression
47 !! attempts to prevent overfitting
48 type, extends(base_regulariser_type) :: l2_regulariser_type
49 !! l2 = L2 regularisation
50 !! l2_decoupled = decoupled weight decay regularisation (AdamW)
51 real(real12) :: l2 = 0.01_real12
52 real(real12) :: l2_decoupled = 0.01_real12
53 logical :: decoupled = .true.
54 contains
55 procedure, pass(this) :: regularise => regularise_l2
56 end type l2_regulariser_type
57
58 type, extends(base_regulariser_type) :: l1l2_regulariser_type
59 real(real12) :: l1 = 0.01_real12
60 real(real12) :: l2 = 0.01_real12
61 contains
62 procedure, pass(this) :: regularise => regularise_l1l2
63 end type l1l2_regulariser_type
64
65
66 private
67
68 public :: base_regulariser_type
69 public :: l1_regulariser_type
70 public :: l2_regulariser_type
71 public :: l1l2_regulariser_type
72
73
74 contains
75
76 !!!#############################################################################
77 !!! regularise
78 !!!#############################################################################
79
2/4
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
5 pure subroutine regularise_l1(this, params, gradient, learning_rate)
80 class(l1_regulariser_type), intent(in) :: this
81 real(real12), dimension(:), intent(in) :: params
82 real(real12), dimension(:), intent(inout) :: gradient
83 real(real12), intent(in) :: learning_rate
84
85
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✓ Branch 33 taken 41 times.
✓ Branch 34 taken 5 times.
46 gradient = gradient + learning_rate * this%l1 * sign(1._real12,params)
86
87 5 end subroutine regularise_l1
88 !!!-----------------------------------------------------------------------------
89
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 pure subroutine regularise_l2(this, params, gradient, learning_rate)
90 class(l2_regulariser_type), intent(in) :: this
91 real(real12), dimension(:), intent(in) :: params
92 real(real12), dimension(:), intent(inout) :: gradient
93 real(real12), intent(in) :: learning_rate
94
95
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✓ Branch 33 taken 21 times.
✓ Branch 34 taken 3 times.
24 gradient = gradient + learning_rate * 2._real12 * this%l2 * params
96
97 3 end subroutine regularise_l2
98 !!!-----------------------------------------------------------------------------
99
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 pure subroutine regularise_l1l2(this, params, gradient, learning_rate)
100 class(l1l2_regulariser_type), intent(in) :: this
101 real(real12), dimension(:), intent(in) :: params
102 real(real12), dimension(:), intent(inout) :: gradient
103 real(real12), intent(in) :: learning_rate
104
105 gradient = gradient + learning_rate * &
106
17/32
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 38 taken 1 times.
✓ Branch 39 taken 1 times.
2 (this%l1 * sign(1._real12,params) + 2._real12 * this%l2 * params)
107
108 1 end subroutine regularise_l1l2
109 !!!#############################################################################
110
111 28 end module regulariser
112