GCC Code Coverage Report


Directory: src/lib/
File: src/lib/mod_clipper.f90
Date: 2024-06-28 12:51:18
Exec Total Coverage
Lines: 54 61 88.5%
Functions: 0 0 -%
Branches: 87 174 50.0%

Line Branch Exec Source
1 !!!#############################################################################
2 !!! Code written by Ned Thaddeus Taylor
3 !!! Code part of the ATHENA library - a feedforward neural network library
4 !!!#############################################################################
5 !!! module contains routines for clipping gradients
6 !!! module includes the following types:
7 !!! - clip_type - type containing clipping information
8 !!!##################
9 !!! clip_type contains the following procedures:
10 !!! - read_clip - read clipping information from strings
11 !!! - set_clip - set clipping information from a dictionary
12 !!! - apply_clip - apply clipping to gradients
13 !!!#############################################################################
14 module clipper
15 use constants, only: real12
16 implicit none
17
18
19 !!!------------------------------------------------------------------------
20 !!! gradient clipping type
21 !!!------------------------------------------------------------------------
22 type clip_type
23 logical :: l_min_max = .false.
24 logical :: l_norm = .false.
25 real(real12) :: min =-huge(1._real12)
26 real(real12) :: max = huge(1._real12)
27 real(real12) :: norm = huge(1._real12)
28 contains
29 procedure, pass(this) :: read => read_clip
30 procedure, pass(this) :: set => set_clip
31 procedure, pass(this) :: apply => apply_clip
32 end type clip_type
33
34 interface clip_type
35 module function clip_setup( &
36 clip_min, clip_max, clip_norm) result(clip)
37 real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm
38 type(clip_type) :: clip
39 end function clip_setup
40 end interface clip_type
41
42
43
44 private
45
46 public :: clip_type
47
48
49 contains
50
51 !!!#############################################################################
52 !!! set clip dictionary
53 !!!#############################################################################
54 1 module function clip_setup( &
55 clip_min, clip_max, clip_norm) result(clip)
56 implicit none
57 real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm
58 type(clip_type) :: clip
59
60
61 !!--------------------------------------------------------------------------
62 !! set up clipping limits
63 !!--------------------------------------------------------------------------
64 1 if(present(clip_min))then
65 1 clip%l_min_max = .true.
66 1 clip%min = clip_min
67 end if
68
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(clip_max))then
69 1 clip%l_min_max = .true.
70 1 clip%max = clip_max
71 end if
72
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(clip_norm))then
73 1 clip%l_norm = .true.
74 1 clip%norm = clip_norm
75 end if
76
77
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
2 end function clip_setup
78 !!!#############################################################################
79
80 !!!#############################################################################
81 !!! get clipping information
82 !!!#############################################################################
83 1 subroutine read_clip(this, min_str, max_str, norm_str)
84 implicit none
85 class(clip_type), intent(inout) :: this
86 character(*), intent(in) :: min_str, max_str, norm_str
87
88
2/4
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
1 if(trim(min_str).ne."")then
89 1 read(min_str,*) this%min
90 else
91 this%min = -huge(1._real12)
92 end if
93
2/4
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 if(trim(max_str).ne."")then
94 1 read(max_str,*) this%max
95 else
96 this%max = huge(1._real12)
97 end if
98
99
3/6
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
1 if(trim(min_str).ne."".or.trim(max_str).ne."")then
100 1 this%l_min_max = .true.
101 end if
102
2/4
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
1 if(trim(norm_str).ne."")then
103 1 read(norm_str,*) this%norm
104 1 this%l_norm = .true.
105 end if
106
107 1 end subroutine read_clip
108 !!!#############################################################################
109
110
111 !!!#############################################################################
112 !!! set clip dictionary
113 !!!#############################################################################
114 2 subroutine set_clip(this, clip_dict, clip_min, clip_max, clip_norm)
115 implicit none
116 class(clip_type), intent(inout) :: this
117 type(clip_type), optional, intent(in) :: clip_dict
118 real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm
119
120
121 !!--------------------------------------------------------------------------
122 !! set up clipping limits
123 !!--------------------------------------------------------------------------
124
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 if(present(clip_dict))then
125 1 this%l_min_max = clip_dict%l_min_max
126 1 this%l_norm = clip_dict%l_norm
127 1 this%min = clip_dict%min
128 1 this%max = clip_dict%max
129 1 this%norm = clip_dict%norm
130
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(present(clip_min).or.present(clip_max).or.present(clip_norm))then
131 write(*,*) "Multiple clip options provided"
132 write(*,*) "Ignoring all except clip_dict"
133 end if
134 else
135
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(clip_min))then
136 1 this%l_min_max = .true.
137 1 this%min = clip_min
138 end if
139
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(clip_max))then
140 1 this%l_min_max = .true.
141 1 this%max = clip_max
142 end if
143
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(clip_norm))then
144 1 this%l_norm = .true.
145 1 this%norm = clip_norm
146 end if
147 end if
148
149 2 end subroutine set_clip
150 !!!#############################################################################
151
152
153 !!!#############################################################################
154 !!! gradient norm clipping
155 !!!#############################################################################
156
3/6
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
1 pure subroutine apply_clip(this, length, gradient, bias)
157 implicit none
158 class(clip_type), intent(in) :: this
159 integer, intent(in) :: length
160 real(real12), dimension(length), intent(inout) :: gradient
161 real(real12), dimension(:), optional, intent(inout) :: bias
162
163 real(real12) :: scale
164 1 real(real12), dimension(:), allocatable :: bias_
165
166
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(bias))then
167
8/16
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✓ Branch 18 taken 1 times.
2 bias_ = bias
168 else
169 allocate(bias_(1), source=0._real12)
170 end if
171
172 !! clip values to within limits of (min,max)
173
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%l_min_max)then
174
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 gradient = max(this%min,min(this%max,gradient))
175
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 1 times.
✓ Branch 13 taken 1 times.
2 bias_ = max(this%min,min(this%max,bias_))
176 end if
177
178 !! clip values to a maximum L2-norm
179
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%l_norm)then
180 scale = min(1._real12, &
181 this%norm/sqrt(sum(gradient**2._real12) + &
182
11/18
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✓ Branch 6 taken 3 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✓ Branch 20 taken 1 times.
✓ Branch 21 taken 1 times.
5 sum(bias_)**2._real12))
183
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(scale.lt.1._real12)then
184
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 3 times.
✓ Branch 22 taken 1 times.
4 gradient = gradient * scale
185
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 1 times.
✓ Branch 13 taken 1 times.
2 bias_ = bias_ * scale
186 end if
187 end if
188
189
12/22
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✓ Branch 14 taken 1 times.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✓ Branch 28 taken 1 times.
✓ Branch 29 taken 1 times.
2 if(present(bias)) bias = bias_
190
191
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 end subroutine apply_clip
192 !!!#############################################################################
193
194 end module clipper
195