Line | Branch | Exec | Source |
---|---|---|---|
!!!#############################################################################
!!! Code written by Ned Thaddeus Taylor
!!! Code part of the ATHENA library - a feedforward neural network library
!!!#############################################################################
!!! module contains routines for clipping gradients
!!! module includes the following types:
!!! - clip_type   - type containing clipping information
!!!##################
!!! clip_type contains the following procedures:
!!! - read_clip   - read clipping information from strings
!!! - set_clip    - set clipping information from another clip_type
!!!                 ("dictionary") or from individual limits
!!! - apply_clip  - apply clipping to gradients (and optional biases)
!!!#############################################################################
module clipper
  use constants, only: real12
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none


!!!-----------------------------------------------------------------------------
!!! gradient clipping type
!!!-----------------------------------------------------------------------------
  type clip_type
     !! .true. when element-wise (min,max) clipping is enabled
     logical :: l_min_max = .false.
     !! .true. when L2-norm clipping is enabled
     logical :: l_norm = .false.
     !! element-wise lower and upper bounds (defaults: +/- huge)
     real(real12) :: min =-huge(1._real12)
     real(real12) :: max = huge(1._real12)
     !! maximum permitted combined L2-norm of gradient and bias (default: huge)
     real(real12) :: norm = huge(1._real12)
   contains
     procedure, pass(this) :: read => read_clip
     procedure, pass(this) :: set => set_clip
     procedure, pass(this) :: apply => apply_clip
  end type clip_type

  !! structure-constructor-like generic: clip_type(clip_min=..., ...)
  interface clip_type
     module function clip_setup( &
          clip_min, clip_max, clip_norm) result(clip)
       real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm
       type(clip_type) :: clip
     end function clip_setup
  end interface clip_type


  private

  public :: clip_type


contains

!!!#############################################################################
!!! set up a clip_type from optional limits
!!!#############################################################################
  module function clip_setup( &
       clip_min, clip_max, clip_norm) result(clip)
    !! Build a clip_type; each supplied limit overrides the default and
    !! enables the corresponding clipping mode.
    implicit none
    real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm
    type(clip_type) :: clip


    !!--------------------------------------------------------------------------
    !! set up clipping limits
    !!--------------------------------------------------------------------------
    if(present(clip_min))then
       clip%l_min_max = .true.
       clip%min = clip_min
    end if
    if(present(clip_max))then
       clip%l_min_max = .true.
       clip%max = clip_max
    end if
    if(present(clip_norm))then
       clip%l_norm = .true.
       clip%norm = clip_norm
    end if

  end function clip_setup
!!!#############################################################################


!!!#############################################################################
!!! get clipping information from strings
!!!#############################################################################
  subroutine read_clip(this, min_str, max_str, norm_str)
    !! Parse clipping limits from (list-directed) strings.
    !! - blank min_str/max_str reset that bound to its default (-/+ huge)
    !! - l_min_max is set if either min_str or max_str is non-blank
    !! - blank norm_str leaves norm and l_norm untouched
    !! A malformed non-blank string terminates with an explanatory error.
    implicit none
    class(clip_type), intent(inout) :: this
    character(*), intent(in) :: min_str, max_str, norm_str


    if(len_trim(min_str).gt.0)then
       call read_clip_value(min_str, "min", this%min)
    else
       this%min = -huge(1._real12)
    end if

    if(len_trim(max_str).gt.0)then
       call read_clip_value(max_str, "max", this%max)
    else
       this%max = huge(1._real12)
    end if

    if(len_trim(min_str).gt.0 .or. len_trim(max_str).gt.0)then
       this%l_min_max = .true.
    end if

    if(len_trim(norm_str).gt.0)then
       call read_clip_value(norm_str, "norm", this%norm)
       this%l_norm = .true.
    end if

  end subroutine read_clip
!!!#############################################################################


!!!#############################################################################
!!! read a single real clipping value from a string (private helper)
!!!#############################################################################
  subroutine read_clip_value(str, label, value)
    !! List-directed read of one real from str into value.
    !! value is intent(inout) so that a list-directed null value leaves it
    !! unchanged, as a plain read into the component would.
    !! On a read error, report label/str to error_unit and error stop.
    implicit none
    character(*), intent(in) :: str, label
    real(real12), intent(inout) :: value

    integer :: stat
    character(256) :: msg


    read(str, *, iostat=stat, iomsg=msg) value
    if(stat.ne.0)then
       write(error_unit,'(A)') &
            "ERROR: cannot read clip "//label//" value from string '"// &
            trim(str)//"'"
       write(error_unit,'(A)') trim(msg)
       error stop "read_clip: invalid clipping value"
    end if

  end subroutine read_clip_value
!!!#############################################################################


!!!#############################################################################
!!! set clip dictionary
!!!#############################################################################
  subroutine set_clip(this, clip_dict, clip_min, clip_max, clip_norm)
    !! Configure this clip_type either by copying every field of clip_dict
    !! or, when clip_dict is absent, from the individually supplied limits.
    !! If clip_dict is given together with individual limits, the limits are
    !! ignored and a warning is printed.
    implicit none
    class(clip_type), intent(inout) :: this
    type(clip_type), optional, intent(in) :: clip_dict
    real(real12), optional, intent(in) :: clip_min, clip_max, clip_norm


    !!--------------------------------------------------------------------------
    !! set up clipping limits
    !!--------------------------------------------------------------------------
    if(present(clip_dict))then
       this%l_min_max = clip_dict%l_min_max
       this%l_norm = clip_dict%l_norm
       this%min = clip_dict%min
       this%max = clip_dict%max
       this%norm = clip_dict%norm
       if(present(clip_min).or.present(clip_max).or.present(clip_norm))then
          write(*,*) "Multiple clip options provided"
          write(*,*) "Ignoring all except clip_dict"
       end if
    else
       if(present(clip_min))then
          this%l_min_max = .true.
          this%min = clip_min
       end if
       if(present(clip_max))then
          this%l_min_max = .true.
          this%max = clip_max
       end if
       if(present(clip_norm))then
          this%l_norm = .true.
          this%norm = clip_norm
       end if
    end if

  end subroutine set_clip
!!!#############################################################################


!!!#############################################################################
!!! gradient clipping (element-wise bounds and L2-norm rescaling)
!!!#############################################################################
  pure subroutine apply_clip(this, length, gradient, bias)
    !! 1. if l_min_max: clamp every gradient (and bias) element to [min,max]
    !! 2. if l_norm: if the combined L2-norm of gradient and bias exceeds
    !!    this%norm, rescale both so that the combined norm equals this%norm
    implicit none
    class(clip_type), intent(in) :: this
    integer, intent(in) :: length
    real(real12), dimension(length), intent(inout) :: gradient
    real(real12), dimension(:), optional, intent(inout) :: bias

    real(real12) :: scale, total_norm
    real(real12), dimension(:), allocatable :: bias_


    !! local copy of the bias; with no bias use an EMPTY array so that it
    !! cannot contribute to the norm (a dummy [0] would become nonzero when
    !! clamped with min > 0)
    if(present(bias))then
       bias_ = bias
    else
       allocate(bias_(0))
    end if

    !! clip values to within limits of (min,max)
    if(this%l_min_max)then
       gradient = max(this%min, min(this%max, gradient))
       bias_    = max(this%min, min(this%max, bias_))
    end if

    !! clip values to a maximum L2-norm
    !! total norm = sqrt(sum(gradient**2) + sum(bias**2)); norm2/hypot avoid
    !! intermediate overflow and the standard-prohibited real power of a
    !! negative base. Guard against a zero norm (no division by zero).
    if(this%l_norm)then
       total_norm = hypot(norm2(gradient), norm2(bias_))
       if(total_norm.gt.this%norm .and. total_norm.gt.0._real12)then
          scale = this%norm / total_norm
          gradient = gradient * scale
          bias_    = bias_ * scale
       end if
    end if

    if(present(bias)) bias = bias_

  end subroutine apply_clip
!!!#############################################################################

end module clipper