Line |
Branch |
Exec |
Source |
1 |
|
|
!!!############################################################################# |
2 |
|
|
!!! Code written by Ned Thaddeus Taylor |
3 |
|
|
!!! Code part of the ATHENA library - a feedforward neural network library |
4 |
|
|
!!!############################################################################# |
5 |
|
|
!!! module contains the network class, which is used to define a neural network |
6 |
|
|
!!! module contains the following derived types: |
7 |
|
|
!!! - network_type |
8 |
|
|
!!!################## |
9 |
|
|
!!! network_type contains the following procedures: |
10 |
|
|
!!! - print - print the network to file |
11 |
|
|
!!! - read - read the network from a file |
12 |
|
|
!!! - add - add a layer to the network |
13 |
|
|
!!! - reset - reset the network |
14 |
|
|
!!! - compile - compile the network |
15 |
|
|
!!! - set_batch_size - set batch size |
16 |
|
|
!!! - set_metrics - set network metrics |
17 |
|
|
!!! - set_loss - set network loss method |
18 |
|
|
!!! - train - train the network |
19 |
|
|
!!! - test - test the network |
20 |
|
|
!!! - predict - return predicted results from supplied inputs using ... |
21 |
|
|
!!! ... the trained network |
22 |
|
|
!!! - update - update the learnable parameters of the network based ... |
23 |
|
|
!!! ... on gradients |
24 |
|
|
!!! - reduce - reduce two networks down to one ... |
25 |
|
|
!!! ... (i.e. add two networks - parallel) |
26 |
|
|
!!! - copy - copy a network |
27 |
|
|
!!! - get_num_params - get number of learnable parameters in the network |
28 |
|
|
!!! - get_params - get learnable parameters |
29 |
|
|
!!! - set_params - set learnable parameters |
30 |
|
|
!!! - get_gradients - get gradients of learnable parameters |
31 |
|
|
!!! - set_gradients - set learnable parameter gradients |
32 |
|
|
!!! - reset_gradients - reset learnable parameter gradients |
33 |
|
|
!!! - forward - forward pass |
34 |
|
|
!!! - backward - backward pass |
35 |
|
|
!!!############################################################################# |
36 |
|
|
module network
  !! Declares network_type, the top-level neural network object of ATHENA,
  !! together with the interfaces of its constructor and type-bound
  !! procedures. This module has no `contains` section: every procedure
  !! below is a separate module procedure whose body lives in a submodule.
  use constants, only: real12
  use metrics, only: metric_dict_type
  use optimiser, only: base_optimiser_type
  use loss, only: &
       comp_loss_func => compute_loss_function, &
       comp_loss_deriv => compute_loss_derivative
  use base_layer, only: base_layer_type
  use container_layer, only: container_layer_type
  implicit none


  private

  public :: network_type


  type :: network_type
     !! Neural network: an array of layer containers plus the optimiser,
     !! metrics and loss procedures used to train and evaluate it.
     !!
     !! NOTE: component order is kept unchanged on purpose, as it defines the
     !! positional order of the intrinsic structure constructor.

     !! Accuracy and loss values of the network. Default-initialised so they
     !! are always defined, matching every other scalar component below
     !! (previously these two were the only ones left undefined).
     real(real12) :: accuracy = 0._real12, loss = 0._real12
     !! Batch size; 0 until set (e.g. via compile or set_batch_size).
     integer :: batch_size = 0
     !! Number of layers; presumably kept in sync with size(model) by the
     !! submodule -- confirm.
     integer :: num_layers = 0
     !! Number of network outputs; 0 until set.
     integer :: num_outputs = 0
     !! Optimiser used to update the learnable parameters.
     class(base_optimiser_type), allocatable :: optimiser
     !! Two metric slots; presumably loss and accuracy -- confirm in the
     !! submodule.
     type(metric_dict_type), dimension(2) :: metrics
     !! Layer containers that make up the network.
     type(container_layer_type), allocatable, dimension(:) :: model
     !! Loss function and its derivative. Both are null until assigned,
     !! presumably by set_loss (confirm in the submodule).
     procedure(comp_loss_func), nopass, pointer :: get_loss => null()
     procedure(comp_loss_deriv), nopass, pointer :: get_loss_deriv => null()
   contains
     procedure, pass(this) :: print
     procedure, pass(this) :: read
     procedure, pass(this) :: add
     procedure, pass(this) :: reset
     procedure, pass(this) :: compile
     procedure, pass(this) :: set_batch_size
     procedure, pass(this) :: set_metrics
     procedure, pass(this) :: set_loss
     procedure, pass(this) :: train
     procedure, pass(this) :: test
     procedure, pass(this) :: predict => predict_1d
     procedure, pass(this) :: update

     procedure, pass(this) :: reduce => network_reduction
     procedure, pass(this) :: copy => network_copy

     procedure, pass(this) :: get_num_params
     procedure, pass(this) :: get_params
     procedure, pass(this) :: set_params
     procedure, pass(this) :: get_gradients
     procedure, pass(this) :: set_gradients
     procedure, pass(this) :: reset_gradients

     procedure, pass(this) :: forward => forward_1d
     procedure, pass(this) :: backward => backward_1d
  end type network_type


  interface network_type
     !!-------------------------------------------------------------------------
     !! setup the network (network initialisation)
     !!-------------------------------------------------------------------------
     !! layers      = (T, in) layer container
     !! optimiser   = (T, in, opt) optimiser
     !! loss_method = (S, in, opt) loss method
     !! metrics     = (*, in, opt) metrics, either string or metric_dict_type
     !! batch_size  = (I, in, opt) batch size
     !! network     = (T, out) the initialised network
     module function network_setup( &
          layers, &
          optimiser, loss_method, metrics, batch_size) result(network)
       type(container_layer_type), dimension(:), intent(in) :: layers
       class(base_optimiser_type), optional, intent(in) :: optimiser
       character(len=*), optional, intent(in) :: loss_method
       class(*), dimension(..), optional, intent(in) :: metrics
       integer, optional, intent(in) :: batch_size
       type(network_type) :: network
     end function network_setup
  end interface network_type


  interface
     !!-------------------------------------------------------------------------
     !! print the network to file
     !!-------------------------------------------------------------------------
     !! this = (T, in) network type
     !! file = (S, in) file name
     module subroutine print(this, file)
       class(network_type), intent(in) :: this
       character(len=*), intent(in) :: file
     end subroutine print

     !!-------------------------------------------------------------------------
     !! read the network from a file
     !!-------------------------------------------------------------------------
     !! this = (T, io) network type
     !! file = (S, in) file name
     module subroutine read(this, file)
       class(network_type), intent(inout) :: this
       character(len=*), intent(in) :: file
     end subroutine read

     !!-------------------------------------------------------------------------
     !! add a layer to the network
     !!-------------------------------------------------------------------------
     !! this  = (T, io) network type
     !! layer = (T, in) layer to add
     module subroutine add(this, layer)
       class(network_type), intent(inout) :: this
       class(base_layer_type), intent(in) :: layer
     end subroutine add

     !!-------------------------------------------------------------------------
     !! reset the network
     !!-------------------------------------------------------------------------
     !! this = (T, io) network type
     module subroutine reset(this)
       class(network_type), intent(inout) :: this
     end subroutine reset

     !!-------------------------------------------------------------------------
     !! compile the network
     !!-------------------------------------------------------------------------
     !! this        = (T, io) network type
     !! optimiser   = (T, in) optimiser
     !! loss_method = (S, in, opt) loss method
     !! metrics     = (*, in, opt) metrics, either string or metric_dict_type
     !! batch_size  = (I, in, opt) batch size
     !! verbose     = (I, in, opt) verbosity level
     module subroutine compile(this, optimiser, loss_method, metrics, &
          batch_size, verbose)
       class(network_type), intent(inout) :: this
       class(base_optimiser_type), intent(in) :: optimiser
       character(len=*), optional, intent(in) :: loss_method
       class(*), dimension(..), optional, intent(in) :: metrics
       integer, optional, intent(in) :: batch_size
       integer, optional, intent(in) :: verbose
     end subroutine compile

     !!-------------------------------------------------------------------------
     !! set batch size
     !!-------------------------------------------------------------------------
     !! this       = (T, io) network type
     !! batch_size = (I, in) batch size to use
     module subroutine set_batch_size(this, batch_size)
       class(network_type), intent(inout) :: this
       integer, intent(in) :: batch_size
     end subroutine set_batch_size

     !!-------------------------------------------------------------------------
     !! set network metrics
     !!-------------------------------------------------------------------------
     !! this    = (T, io) network type
     !! metrics = (*, in) metrics to use
     module subroutine set_metrics(this, metrics)
       class(network_type), intent(inout) :: this
       class(*), dimension(..), intent(in) :: metrics
     end subroutine set_metrics

     !!-------------------------------------------------------------------------
     !! set network loss method
     !!-------------------------------------------------------------------------
     !! this        = (T, io) network type
     !! loss_method = (S, in) loss method to use
     !! verbose     = (I, in, opt) verbosity level
     module subroutine set_loss(this, loss_method, verbose)
       class(network_type), intent(inout) :: this
       character(len=*), intent(in) :: loss_method
       integer, optional, intent(in) :: verbose
     end subroutine set_loss

     !!-------------------------------------------------------------------------
     !! train the network
     !!-------------------------------------------------------------------------
     !! this              = (T, io) network type
     !! input             = (R, in) input data
     !! output            = (*, in) expected output data (data labels)
     !! num_epochs        = (I, in) number of epochs to train for
     !! batch_size        = (I, in, opt) batch size (DEPRECATED)
     !! addit_input       = (R, in, opt) additional input data
     !! addit_layer       = (I, in, opt) layer to insert additional input data
     !! plateau_threshold = (R, in, opt) threshold for checking learning plateau
     !! shuffle_batches   = (B, in, opt) shuffle batch order
     !! batch_print_step  = (I, in, opt) print step for batch
     !! verbose           = (I, in, opt) verbosity level
     module subroutine train(this, input, output, num_epochs, batch_size, &
          addit_input, addit_layer, &
          plateau_threshold, shuffle_batches, batch_print_step, verbose)
       class(network_type), intent(inout) :: this
       real(real12), dimension(..), intent(in) :: input
       class(*), dimension(:,:), intent(in) :: output
       integer, intent(in) :: num_epochs
       integer, optional, intent(in) :: batch_size !! deprecated
       real(real12), dimension(:,:), optional, intent(in) :: addit_input
       integer, optional, intent(in) :: addit_layer
       real(real12), optional, intent(in) :: plateau_threshold
       logical, optional, intent(in) :: shuffle_batches
       integer, optional, intent(in) :: batch_print_step
       integer, optional, intent(in) :: verbose
     end subroutine train

     !!-------------------------------------------------------------------------
     !! test the network
     !!-------------------------------------------------------------------------
     !! this        = (T, io) network type
     !! input       = (R, in) input data
     !! output      = (*, in) expected output data (data labels)
     !! addit_input = (R, in, opt) additional input data
     !! addit_layer = (I, in, opt) layer to insert additional input data
     !! verbose     = (I, in, opt) verbosity level
     module subroutine test(this, input, output, &
          addit_input, addit_layer, &
          verbose)
       class(network_type), intent(inout) :: this
       real(real12), dimension(..), intent(in) :: input
       class(*), dimension(:,:), intent(in) :: output
       real(real12), dimension(:,:), optional, intent(in) :: addit_input
       integer, optional, intent(in) :: addit_layer
       integer, optional, intent(in) :: verbose
     end subroutine test

     !!-------------------------------------------------------------------------
     !! return predicted results from supplied inputs using the trained network
     !!-------------------------------------------------------------------------
     !! this        = (T, io) network type (modified: runs a forward pass)
     !! input       = (R, in) input data
     !! addit_input = (R, in, opt) additional input data
     !! addit_layer = (I, in, opt) layer to insert additional input data
     !! verbose     = (I, in, opt) verbosity level
     !! output      = (R, out) predicted output data
     module function predict_1d(this, input, &
          addit_input, addit_layer, &
          verbose) result(output)
       class(network_type), intent(inout) :: this
       real(real12), dimension(..), intent(in) :: input
       real(real12), dimension(:,:), optional, intent(in) :: addit_input
       integer, optional, intent(in) :: addit_layer
       integer, optional, intent(in) :: verbose
       real(real12), dimension(:,:), allocatable :: output
     end function predict_1d

     !!-------------------------------------------------------------------------
     !! update the learnable parameters of the network based on gradients
     !!-------------------------------------------------------------------------
     !! this = (T, io) network type
     module subroutine update(this)
       class(network_type), intent(inout) :: this
     end subroutine update

     !!-------------------------------------------------------------------------
     !! reduce two networks down to one (i.e. add two networks - parallel)
     !!-------------------------------------------------------------------------
     !! this   = (T, io) network type, resultant network of the reduction
     !! source = (T, in) network type
     module subroutine network_reduction(this, source)
       class(network_type), intent(inout) :: this
       type(network_type), intent(in) :: source
     end subroutine network_reduction

     !!-------------------------------------------------------------------------
     !! copy a network
     !!-------------------------------------------------------------------------
     !! this   = (T, io) network type, resultant network of the copy
     !! source = (T, in) network type
     module subroutine network_copy(this, source)
       class(network_type), intent(inout) :: this
       type(network_type), intent(in) :: source
     end subroutine network_copy

     !!-------------------------------------------------------------------------
     !! get number of learnable parameters in the network
     !!-------------------------------------------------------------------------
     !! this       = (T, in) network type
     !! num_params = (I, out) number of parameters
     pure module function get_num_params(this) result(num_params)
       class(network_type), intent(in) :: this
       integer :: num_params
     end function get_num_params

     !!-------------------------------------------------------------------------
     !! get learnable parameters
     !!-------------------------------------------------------------------------
     !! this   = (T, in) network type
     !! params = (R, out) learnable parameters
     pure module function get_params(this) result(params)
       class(network_type), intent(in) :: this
       real(real12), allocatable, dimension(:) :: params
     end function get_params

     !!-------------------------------------------------------------------------
     !! set learnable parameters
     !!-------------------------------------------------------------------------
     !! this   = (T, io) network type
     !! params = (R, in) learnable parameters
     module subroutine set_params(this, params)
       class(network_type), intent(inout) :: this
       real(real12), dimension(:), intent(in) :: params
     end subroutine set_params

     !!-------------------------------------------------------------------------
     !! get gradients of learnable parameters
     !!-------------------------------------------------------------------------
     !! this      = (T, in) network type
     !! gradients = (R, out) gradients
     pure module function get_gradients(this) result(gradients)
       class(network_type), intent(in) :: this
       real(real12), allocatable, dimension(:) :: gradients
     end function get_gradients

     !!-------------------------------------------------------------------------
     !! set learnable parameter gradients
     !!-------------------------------------------------------------------------
     !! this      = (T, io) network type
     !! gradients = (R, in) gradients (assumed rank)
     module subroutine set_gradients(this, gradients)
       class(network_type), intent(inout) :: this
       real(real12), dimension(..), intent(in) :: gradients
     end subroutine set_gradients

     !!-------------------------------------------------------------------------
     !! reset learnable parameter gradients
     !!-------------------------------------------------------------------------
     !! this = (T, io) network type
     module subroutine reset_gradients(this)
       class(network_type), intent(inout) :: this
     end subroutine reset_gradients

     !!-------------------------------------------------------------------------
     !! forward pass
     !!-------------------------------------------------------------------------
     !! this        = (T, io) network type
     !! input       = (R, in) input data
     !! addit_input = (R, in, opt) additional input data
     !! layer       = (I, in, opt) layer to insert additional input data
     pure module subroutine forward_1d(this, input, addit_input, layer)
       class(network_type), intent(inout) :: this
       real(real12), dimension(..), intent(in) :: input
       real(real12), dimension(:,:), optional, intent(in) :: addit_input
       integer, optional, intent(in) :: layer
     end subroutine forward_1d

     !!-------------------------------------------------------------------------
     !! backward pass
     !!-------------------------------------------------------------------------
     !! this   = (T, io) network type
     !! output = (R, in) output data
     pure module subroutine backward_1d(this, output)
       class(network_type), intent(inout) :: this
       real(real12), dimension(:,:), intent(in) :: output
     end subroutine backward_1d
  end interface

end module network
387 |
|
|
!!!############################################################################# |
388 |
|
|
|