GCC Code Coverage Report


Directory: src/athena/
File: athena_regulariser.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__regulariser
2 !! Module containing regularisation methods
3 !!
4 !! This module provides L1, L2, and combined L1-L2 regularisation,
5 !! which penalise large weights to reduce overfitting in neural networks
6 use, intrinsic :: iso_fortran_env, only: real32
7 implicit none
8
9
10 private
11
12 public :: base_regulariser_type
13 public :: l1_regulariser_type
14 public :: l2_regulariser_type
15 public :: l1l2_regulariser_type
16
17
18 type, abstract :: base_regulariser_type
19 !! Abstract type for regularisation
20 contains
21 procedure(regularise), deferred, pass(this) :: regularise
22 !! Regularisation method
23 end type base_regulariser_type
24
25 abstract interface
26 pure subroutine regularise(this, params, gradient, learning_rate)
27 !! Regularise the parameters
28 import :: base_regulariser_type, real32
29 class(base_regulariser_type), intent(in) :: this
30 !! Regulariser object
31 real(real32), dimension(:), intent(in) :: params
32 !! Parameters to regularise
33 real(real32), dimension(:), intent(inout) :: gradient
34 !! Gradient of the parameters
35 real(real32), intent(in) :: learning_rate
36 !! Learning rate
37 end subroutine regularise
38 end interface
39
40 type, extends(base_regulariser_type) :: l1_regulariser_type
41 !! Type for L1 regularisation
42 !!
43 !! L1 regularisation penalises the absolute value of the weights,
44 !! as in lasso regression, and encourages sparse parameters
45 real(real32) :: l1 = 0.01_real32 !! Regularisation parameter
46 contains
47 procedure, pass(this) :: regularise => regularise_l1
48 !! Regularisation method
49 end type l1_regulariser_type
50
51 type, extends(base_regulariser_type) :: l2_regulariser_type
52 !! Type for L2 regularisation
53 !!
54 !! L2 regularisation penalises the squared magnitude of the weights,
55 !! as in ridge regression
56 !! l2 = coupled L2 regularisation parameter
57 !! l2_decoupled = decoupled weight decay parameter (AdamW)
58 real(real32) :: l2 = 0.01_real32
59 !! Regularisation parameter
60 real(real32) :: l2_decoupled = 0.01_real32
61 !! Decoupled weight decay regularisation parameter
62 logical :: decoupled = .true.
63 !! Whether to use decoupled weight decay regularisation
64 contains
65 procedure, pass(this) :: regularise => regularise_l2
66 !! Regularisation method
67 end type l2_regulariser_type
68
69 type, extends(base_regulariser_type) :: l1l2_regulariser_type
70 !! Type for combined L1 and L2 regularisation (elastic net)
71 real(real32) :: l1 = 0.01_real32
72 !! L1 regularisation parameter
73 real(real32) :: l2 = 0.01_real32
74 !! L2 regularisation parameter
75 contains
76 procedure, pass(this) :: regularise => regularise_l1l2
77 !! Regularisation method
78 end type l1l2_regulariser_type
79
80
81
82 contains
83
84 !###############################################################################
85 pure subroutine regularise_l1(this, params, gradient, learning_rate)
86 !! Regularise the parameters using L1 regularisation
87 implicit none
88
89 ! Arguments
90 class(l1_regulariser_type), intent(in) :: this
91 !! Instance of the L1 regulariser
92 real(real32), dimension(:), intent(in) :: params
93 !! Parameters to regularise
94 real(real32), dimension(:), intent(inout) :: gradient
95 !! Gradient of the parameters
96 real(real32), intent(in) :: learning_rate
97 !! Learning rate
98
99 gradient = gradient + learning_rate * this%l1 * sign(1._real32,params) ! d(l1*|w|)/dw = l1*sign(w)
100
101 end subroutine regularise_l1
102 !-------------------------------------------------------------------------------
103 pure subroutine regularise_l2(this, params, gradient, learning_rate)
104 !! Regularise the parameters using L2 regularisation
105 implicit none
106
107 ! Arguments
108 class(l2_regulariser_type), intent(in) :: this
109 !! Instance of the L2 regulariser
110 real(real32), dimension(:), intent(in) :: params
111 !! Parameters to regularise
112 real(real32), dimension(:), intent(inout) :: gradient
113 !! Gradient of the parameters
114 real(real32), intent(in) :: learning_rate
115 !! Learning rate
116
117 gradient = gradient + learning_rate * 2._real32 * this%l2 * params ! d(l2*w**2)/dw = 2*l2*w
118
119 end subroutine regularise_l2
120 !-------------------------------------------------------------------------------
121 pure subroutine regularise_l1l2(this, params, gradient, learning_rate)
122 !! Regularise the parameters using L1 and L2 regularisation
123 implicit none
124
125 ! Arguments
126 class(l1l2_regulariser_type), intent(in) :: this
127 !! Instance of the L1 and L2 regulariser
128 real(real32), dimension(:), intent(in) :: params
129 !! Parameters to regularise
130 real(real32), dimension(:), intent(inout) :: gradient
131 !! Gradient of the parameters
132 real(real32), intent(in) :: learning_rate
133 !! Learning rate
134
135 gradient = gradient + learning_rate * &
136 (this%l1 * sign(1._real32,params) + 2._real32 * this%l2 * params) ! sum of the L1 and L2 penalty gradients
137
138 end subroutine regularise_l1l2
139 !###############################################################################
140
141 end module athena__regulariser
142
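
Below is a minimal usage sketch. It is a hypothetical driver, not part of the
library or of this coverage run: the program name, the parameter values, and
the learning rate are illustrative only. It shows the intended calling
pattern: each regulariser adds its penalty gradient to an existing gradient
in place.

program regulariser_example
   !! Hypothetical driver, for illustration only
   use, intrinsic :: iso_fortran_env, only: real32
   use athena__regulariser, only: base_regulariser_type, &
        l1_regulariser_type, l2_regulariser_type
   implicit none

   class(base_regulariser_type), allocatable :: regulariser
   real(real32) :: params(3), gradient(3)

   params   = [ 0.5_real32, -1.0_real32, 2.0_real32 ]
   gradient = [ 0.1_real32,  0.2_real32, 0.3_real32 ]

   ! L2: gradient <- gradient + learning_rate * 2 * l2 * params
   allocate(regulariser, source = l2_regulariser_type(l2 = 0.01_real32))
   call regulariser%regularise(params, gradient, learning_rate = 0.1_real32)
   print *, gradient   ! expect 0.101, 0.198, 0.304

   ! L1: gradient <- gradient + learning_rate * l1 * sign(1._real32, params)
   deallocate(regulariser)
   allocate(l1_regulariser_type :: regulariser)   ! default l1 = 0.01
   call regulariser%regularise(params, gradient, learning_rate = 0.1_real32)
   print *, gradient   ! expect 0.102, 0.197, 0.305

end program regulariser_example

Note that regularise_l2 as listed applies only the coupled l2 term; the
l2_decoupled and decoupled components do not appear in its body, so they are
presumably read elsewhere (for example by an AdamW-style optimiser that
applies the decoupled decay step itself). That reading is an inference from
the listing, not something this report confirms.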