GCC Code Coverage Report


Directory: src/lib/
File: src/lib/mod_network.f90
Date: 2024-06-28 12:51:18
Exec Total Coverage
Lines: 0 1 0.0%
Functions: 0 0 -%
Branches: 0 86 0.0%

Line Branch Exec Source
1 !!!#############################################################################
2 !!! Code written by Ned Thaddeus Taylor
3 !!! Code part of the ATHENA library - a feedforward neural network library
4 !!!#############################################################################
5 !!! module contains the network class, which is used to define a neural network
6 !!! module contains the following derived types:
7 !!! - network_type
8 !!!##################
9 !!! network_type contains the following procedures:
10 !!! - print - print the network to file
11 !!! - read - read the network from a file
12 !!! - add - add a layer to the network
13 !!! - reset - reset the network
14 !!! - compile - compile the network
15 !!! - set_batch_size - set batch size
16 !!! - set_metrics - set network metrics
17 !!! - set_loss - set network loss method
18 !!! - train - train the network
19 !!! - test - test the network
20 !!! - predict - return predicted results from supplied inputs using ...
21 !!! ... the trained network
22 !!! - update - update the learnable parameters of the network based ...
23 !!! ... on gradients
24 !!! - reduce - reduce two networks down to one ...
25 !!! ... (i.e. add two networks - parallel)
26 !!! - copy - copy a network
27 !!! - get_num_params - get number of learnable parameters in the network
28 !!! - get_params - get learnable parameters
29 !!! - set_params - set learnable parameters
30 !!! - get_gradients - get gradients of learnable parameters
31 !!! - set_gradients - set learnable parameter gradients
32 !!! - reset_gradients - reset learnable parameter gradients
33 !!! - forward - forward pass
34 !!! - backward - backward pass
35 !!!#############################################################################
36 module network
37 use constants, only: real12
38 use metrics, only: metric_dict_type
39 use optimiser, only: base_optimiser_type
40 use loss, only: &
41 comp_loss_func => compute_loss_function, &
42 comp_loss_deriv => compute_loss_derivative
43 use base_layer, only: base_layer_type
44 use container_layer, only: container_layer_type
45 implicit none
46
47 private
48
49 public :: network_type
50
51
52 type :: network_type
53 real(real12) :: accuracy, loss ! NOTE: not default-initialised; presumably set by train/test - confirm
54 integer :: batch_size = 0 ! defaults to 0; see set_batch_size
55 integer :: num_layers = 0 ! defaults to 0
56 integer :: num_outputs = 0 ! defaults to 0
57 class(base_optimiser_type), allocatable :: optimiser ! presumably allocated via compile/network_setup - confirm
58 type(metric_dict_type), dimension(2) :: metrics ! fixed two entries; presumably loss and accuracy - confirm
59 type(container_layer_type), allocatable, dimension(:) :: model ! layer containers making up the network
60 procedure(comp_loss_func), nopass, pointer :: get_loss => null() ! unassociated until a loss method is set (presumably set_loss)
61 procedure(comp_loss_deriv), nopass, pointer :: get_loss_deriv => null() ! likewise for the loss derivative
62 contains
63 procedure, pass(this) :: print
64 procedure, pass(this) :: read
65 procedure, pass(this) :: add
66 procedure, pass(this) :: reset
67 procedure, pass(this) :: compile
68 procedure, pass(this) :: set_batch_size
69 procedure, pass(this) :: set_metrics
70 procedure, pass(this) :: set_loss
71 procedure, pass(this) :: train
72 procedure, pass(this) :: test
73 procedure, pass(this) :: predict => predict_1d
74 procedure, pass(this) :: update
75
76 procedure, pass(this) :: reduce => network_reduction
77 procedure, pass(this) :: copy => network_copy
78
79 procedure, pass(this) :: get_num_params
80 procedure, pass(this) :: get_params
81 procedure, pass(this) :: set_params
82 procedure, pass(this) :: get_gradients
83 procedure, pass(this) :: set_gradients
84 procedure, pass(this) :: reset_gradients
85
86 procedure, pass(this) :: forward => forward_1d
87 procedure, pass(this) :: backward => backward_1d
88 end type network_type
89
90 interface network_type ! generic named after the type: extends the structure constructor
91 !!-------------------------------------------------------------------------
92 !! setup the network (network initialisation)
93 !!-------------------------------------------------------------------------
94 !! layers = (T, in) array of layer containers
95 !! optimiser = (T, in, opt) optimiser
96 !! loss_method = (S, in, opt) loss method
97 !! metrics = (*, in, opt) metrics, either string or metric_dict_type
98 !! batch_size = (I, in, opt) batch size
99 module function network_setup( &
100 layers, &
101 optimiser, loss_method, metrics, batch_size) result(network)
102 type(container_layer_type), dimension(:), intent(in) :: layers
103 class(base_optimiser_type), optional, intent(in) :: optimiser
104 character(*), optional, intent(in) :: loss_method
105 class(*), dimension(..), optional, intent(in) :: metrics
106 integer, optional, intent(in) :: batch_size
107 type(network_type) :: network
108 end function network_setup
109 end interface network_type
110
111 interface ! separate module procedures; bodies presumably in a submodule (not shown)
112 !!-------------------------------------------------------------------------
113 !! print the network to file
114 !!-------------------------------------------------------------------------
115 !! this = (T, in) network type
116 !! file = (S, in) file name
117 module subroutine print(this, file)
118 class(network_type), intent(in) :: this
119 character(*), intent(in) :: file
120 end subroutine print
121
122 !!-------------------------------------------------------------------------
123 !! read the network from a file
124 !!-------------------------------------------------------------------------
125 !! this = (T, io) network type
126 !! file = (S, in) file name
127 module subroutine read(this, file)
128 class(network_type), intent(inout) :: this
129 character(*), intent(in) :: file
130 end subroutine read
131
132 !!-------------------------------------------------------------------------
133 !! add a layer to the network
134 !!-------------------------------------------------------------------------
135 !! this = (T, io) network type
136 !! layer = (T, in) layer to add
137 module subroutine add(this, layer)
138 class(network_type), intent(inout) :: this
139 class(base_layer_type), intent(in) :: layer
140 end subroutine add
141
142 !!-------------------------------------------------------------------------
143 !! reset the network
144 !!-------------------------------------------------------------------------
145 !! this = (T, io) network type
146 module subroutine reset(this)
147 class(network_type), intent(inout) :: this
148 end subroutine reset
149
150 !!-------------------------------------------------------------------------
151 !! compile the network
152 !!-------------------------------------------------------------------------
153 !! this = (T, io) network type
154 !! optimiser = (T, in) optimiser
155 !! loss_method = (S, in, opt) loss method
156 !! metrics = (*, in, opt) metrics, either string or metric_dict_type
157 !! batch_size = (I, in, opt) batch size
158 !! verbose = (I, in, opt) verbosity level
159 module subroutine compile(this, optimiser, loss_method, metrics, &
160 batch_size, verbose)
161 class(network_type), intent(inout) :: this
162 class(base_optimiser_type), intent(in) :: optimiser
163 character(*), optional, intent(in) :: loss_method
164 class(*), dimension(..), optional, intent(in) :: metrics
165 integer, optional, intent(in) :: batch_size
166 integer, optional, intent(in) :: verbose
167 end subroutine compile
168
169 !!-------------------------------------------------------------------------
170 !! set batch size
171 !!-------------------------------------------------------------------------
172 !! this = (T, io) network type
173 !! batch_size = (I, in) batch size to use
174 module subroutine set_batch_size(this, batch_size)
175 class(network_type), intent(inout) :: this
176 integer, intent(in) :: batch_size
177 end subroutine set_batch_size
178
179 !!-------------------------------------------------------------------------
180 !! set network metrics
181 !!-------------------------------------------------------------------------
182 !! this = (T, io) network type
183 !! metrics = (*, in) metrics to use
184 module subroutine set_metrics(this, metrics)
185 class(network_type), intent(inout) :: this
186 class(*), dimension(..), intent(in) :: metrics
187 end subroutine set_metrics
188
189 !!-------------------------------------------------------------------------
190 !! set network loss method
191 !!-------------------------------------------------------------------------
192 !! this = (T, io) network type
193 !! loss_method = (S, in) loss method to use
194 !! verbose = (I, in, opt) verbosity level
195 module subroutine set_loss(this, loss_method, verbose)
196 class(network_type), intent(inout) :: this
197 character(*), intent(in) :: loss_method
198 integer, optional, intent(in) :: verbose
199 end subroutine set_loss
200
201 !!-------------------------------------------------------------------------
202 !! train the network
203 !!-------------------------------------------------------------------------
204 !! this = (T, io) network type
205 !! input = (R, in) input data
206 !! output = (*, in) expected output data (data labels)
207 !! num_epochs = (I, in) number of epochs to train for
208 !! batch_size = (I, in, opt) batch size (DEPRECATED)
209 !! addit_input = (R, in, opt) additional input data
210 !! addit_layer = (I, in, opt) layer to insert additional input data
211 !! plateau_threshold = (R, in, opt) threshold for checking learning plateau
212 !! shuffle_batches = (B, in, opt) shuffle batch order
213 !! batch_print_step = (I, in, opt) print step for batch
214 !! verbose = (I, in, opt) verbosity level
215 module subroutine train(this, input, output, num_epochs, batch_size, &
216 addit_input, addit_layer, &
217 plateau_threshold, shuffle_batches, batch_print_step, verbose)
218 class(network_type), intent(inout) :: this
219 real(real12), dimension(..), intent(in) :: input
220 class(*), dimension(:,:), intent(in) :: output
221 integer, intent(in) :: num_epochs
222 integer, optional, intent(in) :: batch_size !! deprecated
223 real(real12), dimension(:,:), optional, intent(in) :: addit_input
224 integer, optional, intent(in) :: addit_layer
225 real(real12), optional, intent(in) :: plateau_threshold
226 logical, optional, intent(in) :: shuffle_batches
227 integer, optional, intent(in) :: batch_print_step
228 integer, optional, intent(in) :: verbose
229 end subroutine train
230
231 !!-------------------------------------------------------------------------
232 !! test the network
233 !!-------------------------------------------------------------------------
234 !! this = (T, io) network type
235 !! input = (R, in) input data
236 !! output = (*, in) expected output data (data labels)
237 !! addit_input = (R, in, opt) additional input data
238 !! addit_layer = (I, in, opt) layer to insert additional input data
239 !! verbose = (I, in, opt) verbosity level
240 module subroutine test(this, input, output, &
241 addit_input, addit_layer, &
242 verbose)
243 class(network_type), intent(inout) :: this
244 real(real12), dimension(..), intent(in) :: input
245 class(*), dimension(:,:), intent(in) :: output
246 real(real12), dimension(:,:), optional, intent(in) :: addit_input
247 integer, optional, intent(in) :: addit_layer
248 integer, optional, intent(in) :: verbose
249 end subroutine test
250
251 !!-------------------------------------------------------------------------
252 !! return predicted results from supplied inputs using the trained network
253 !!-------------------------------------------------------------------------
254 !! this = (T, io) network type
255 !! input = (R, in) input data
256 !! addit_input = (R, in, opt) additional input data
257 !! addit_layer = (I, in, opt) layer to insert additional input data
258 !! verbose = (I, in, opt) verbosity level
259 !! output = (R, out) predicted output data
260 module function predict_1d(this, input, &
261 addit_input, addit_layer, &
262 verbose) result(output)
263 class(network_type), intent(inout) :: this ! intent(inout): this function may modify the network
264 real(real12), dimension(..), intent(in) :: input
265 real(real12), dimension(:,:), optional, intent(in) :: addit_input
266 integer, optional, intent(in) :: addit_layer
267 integer, optional, intent(in) :: verbose
268 real(real12), dimension(:,:), allocatable :: output
269 end function predict_1d
270
271 !!-------------------------------------------------------------------------
272 !! update the learnable parameters of the network based on gradients
273 !!-------------------------------------------------------------------------
274 !! this = (T, io) network type
275 module subroutine update(this)
276 class(network_type), intent(inout) :: this
277 end subroutine update
278
279 !!-------------------------------------------------------------------------
280 !! reduce two networks down to one (i.e. add two networks - parallel)
281 !!-------------------------------------------------------------------------
282 !! this = (T, io) network type, resultant network of the reduction
283 !! source = (T, in) network type
284 module subroutine network_reduction(this, source)
285 class(network_type), intent(inout) :: this
286 type(network_type), intent(in) :: source
287 end subroutine network_reduction
288
289 !!-------------------------------------------------------------------------
290 !! copy a network
291 !!-------------------------------------------------------------------------
292 !! this = (T, io) network type, resultant network of the copy
293 !! source = (T, in) network type
294 module subroutine network_copy(this, source)
295 class(network_type), intent(inout) :: this
296 type(network_type), intent(in) :: source
297 end subroutine network_copy
298
299 !!-------------------------------------------------------------------------
300 !! get number of learnable parameters in the network
301 !!-------------------------------------------------------------------------
302 !! this = (T, in) network type
303 !! num_params = (I, out) number of parameters
304 pure module function get_num_params(this) result(num_params)
305 class(network_type), intent(in) :: this
306 integer :: num_params
307 end function get_num_params
308
309 !!-------------------------------------------------------------------------
310 !! get learnable parameters
311 !!-------------------------------------------------------------------------
312 !! this = (T, in) network type
313 !! params = (R, out) learnable parameters
314 pure module function get_params(this) result(params)
315 class(network_type), intent(in) :: this
316 real(real12), allocatable, dimension(:) :: params
317 end function get_params
318
319 !!-------------------------------------------------------------------------
320 !! set learnable parameters
321 !!-------------------------------------------------------------------------
322 !! this = (T, io) network type
323 !! params = (R, in) learnable parameters
324 !! (no verbose argument; params presumably ordered as by get_params - confirm)
325 module subroutine set_params(this, params)
326 class(network_type), intent(inout) :: this
327 real(real12), dimension(:), intent(in) :: params
328 end subroutine set_params
329
330 !!-------------------------------------------------------------------------
331 !! get gradients of learnable parameters
332 !!-------------------------------------------------------------------------
333 !! this = (T, in) network type
334 !! gradients = (R, out) gradients
335 pure module function get_gradients(this) result(gradients)
336 class(network_type), intent(in) :: this
337 real(real12), allocatable, dimension(:) :: gradients
338 end function get_gradients
339
340 !!-------------------------------------------------------------------------
341 !! set learnable parameter gradients
342 !!-------------------------------------------------------------------------
343 !! this = (T, io) network type
344 !! gradients = (R, in) gradients, assumed-rank
345 !! (no verbose argument; presumably ordered as by get_gradients - confirm)
346 module subroutine set_gradients(this, gradients)
347 class(network_type), intent(inout) :: this
348 real(real12), dimension(..), intent(in) :: gradients
349 end subroutine set_gradients
350
351 !!-------------------------------------------------------------------------
352 !! reset learnable parameter gradients
353 !!-------------------------------------------------------------------------
354 !! this = (T, io) network type
355 !! (no arguments other than this)
356 !!-------------------------------------------------------------------------
357 module subroutine reset_gradients(this)
358 class(network_type), intent(inout) :: this
359 end subroutine reset_gradients
360
361 !!-------------------------------------------------------------------------
362 !! forward pass
363 !!-------------------------------------------------------------------------
364 !! this = (T, io) network type
365 !! input = (R, in) input data, assumed-rank
366 !! addit_input = (R, in, opt) additional input data
367 !! layer = (I, in, opt) layer to insert additional input data
368 pure module subroutine forward_1d(this, input, addit_input, layer)
369 class(network_type), intent(inout) :: this
370 real(real12), dimension(..), intent(in) :: input
371 real(real12), dimension(:,:), optional, intent(in) :: addit_input
372 integer, optional, intent(in) :: layer
373 end subroutine forward_1d
374
375 !!-------------------------------------------------------------------------
376 !! backward pass
377 !!-------------------------------------------------------------------------
378 !! this = (T, io) network type
379 !! output = (R, in) output data (presumably expected values/labels - confirm)
380 pure module subroutine backward_1d(this, output)
381 class(network_type), intent(inout) :: this
382 real(real12), dimension(:,:), intent(in) :: output
383 end subroutine backward_1d
384 end interface
385
386 end module network
387 !!!#############################################################################
388