GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_base_layer_sub.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 110 191 57.6%
Functions: 0 0 -%
Branches: 295 1134 26.0%

Line Branch Exec Source
1 submodule(athena__base_layer) athena__base_layer_submodule
2 !! Submodule containing the implementation of the base layer types
3 !!
4 !! This submodule contains the implementation of the base layer types
5 !! used in the ATHENA library. The base layer types are the abstract
6 !! types from which all other layer types are derived. The submodule
7 !! contains the implementation of the procedures that are common to
8 !! all layer types, such as setting the input shape, getting the
9 !! number of parameters, and printing the layer to a file.
10 !!
11 !! The following procedures are based on code from the neural-fortran library
12 !! https://github.com/modern-fortran/neural-fortran/blob/main/src/nf/nf_layer.f90
13 !! procedures:
14 !! - get_num_params*
15 !! - get_params*
16 !! - set_params*
17 !! - get_gradients*
18 !! - set_gradients*
19 use coreutils, only: stop_program, print_warning
20
21 contains
22
23 !###############################################################################
24 5 module function get_attributes_base(this) result(attributes)
25 !! Get the attributes of the layer (for ONNX export)
26 implicit none
27
28 ! Arguments
29 class(base_layer_type), intent(in) :: this
30 !! Instance of the layer
31 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
32 !! Attributes of the layer
33
34 ! Allocate attributes array
35
8/24
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
5 allocate(attributes(0))
36 ! attributes(0)%name = this%name
37 ! attributes(0)%val = this%get_type_name()
38 ! attributes(0)%type = ""
39
40 5 end function get_attributes_base
41 !-------------------------------------------------------------------------------
42 2 module function get_attributes_conv(this) result(attributes)
43 !! Get the attributes of a convolutional layer (for ONNX export)
44 implicit none
45
46 ! Arguments
47 class(conv_layer_type), intent(in) :: this
48 !! Instance of the layer
49 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
50 !! Attributes of the layer
51
52 ! Local variables
53 character(256) :: buffer, fmt
54 !! Buffer for formatting
55
56 ! Allocate attributes array
57
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 6 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 6 times.
8 allocate(attributes(3))
58
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(1)%name = "kernel_shape"
59 2 write(fmt,'("(",I0,"(1X,I0))")') size(this%knl)
60 2 write(buffer,fmt) this%knl
61
6/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
2 attributes(1)%val = trim(adjustl(buffer))
62
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(1)%type = "ints"
63
64
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(2)%name = "strides"
65 2 write(fmt,'("(",I0,"(1X,I0))")') size(this%stp)
66 2 write(buffer,fmt) this%stp
67
6/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
2 attributes(2)%val = trim(adjustl(buffer))
68
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(2)%type = "ints"
69
70
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(3)%name = "dilations"
71 2 write(fmt,'("(",I0,"(1X,I0))")') size(this%dil)
72 2 write(buffer,fmt) this%dil
73
6/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
2 attributes(3)%val = trim(adjustl(buffer))
74
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(3)%type = "ints"
75
76 2 end function get_attributes_conv
77 !-------------------------------------------------------------------------------
78 2 module function get_attributes_pool(this) result(attributes)
79 !! Get the attributes of a pooling layer (for ONNX export)
80 implicit none
81
82 ! Arguments
83 class(pool_layer_type), intent(in) :: this
84 !! Instance of the layer
85 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
86 !! Attributes of the layer
87
88 ! Local variables
89 character(256) :: buffer, fmt
90 !! Buffer for formatting
91
92 ! Allocate attributes array
93
13/24
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 4 times.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 4 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 4 times.
6 allocate(attributes(2))
94
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(1)%name = "kernel_shape"
95 2 write(fmt,'("(",I0,"(1X,I0))")') size(this%pool)
96 2 write(buffer,fmt) this%pool
97
6/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
2 attributes(1)%val = trim(adjustl(buffer))
98
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(1)%type = "ints"
99
100
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(2)%name = "strides"
101 2 write(fmt,'("(",I0,"(1X,I0))")') size(this%strd)
102 2 write(buffer,fmt) this%strd
103
6/14
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
2 attributes(2)%val = trim(adjustl(buffer))
104
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
2 attributes(2)%type = "ints"
105
106 2 end function get_attributes_pool
107 !-------------------------------------------------------------------------------
108 module function get_attributes_batch(this) result(attributes)
109 !! Get the attributes of a batch normalisation layer (for ONNX export)
110 implicit none
111
112 ! Arguments
113 class(batch_layer_type), intent(in) :: this
114 !! Instance of the layer
115 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
116 !! Attributes of the layer
117
118 ! Local variables
119 character(256) :: buffer, fmt
120 !! Buffer for formatting
121
122 ! Allocate attributes array
123 allocate(attributes(4))
124 attributes(1)%name = "epsilon"
125 write(buffer,'("(",F0.6,")")') this%epsilon
126 attributes(1)%val = trim(adjustl(buffer))
127 attributes(1)%type = "float"
128
129 attributes(2)%name = "momentum"
130 write(buffer,'("(",F0.6,")")') this%momentum
131 attributes(2)%val = trim(adjustl(buffer))
132 attributes(2)%type = "float"
133
134 attributes(3)%name = "scale"
135 write(fmt,'("(",I0,"(1X,F0.6))")') this%num_channels
136 write(buffer,fmt) this%params(1)%val(1:this%num_channels,1)
137 attributes(3)%val = trim(adjustl(buffer))
138 attributes(3)%type = "float"
139
140 attributes(4)%name = "B"
141 write(fmt,'("(",I0,"(1X,F0.6))")') this%num_channels
142 write(buffer,fmt) this%params(1)%val(this%num_channels+1:2*this%num_channels,1)
143 attributes(4)%val = trim(adjustl(buffer))
144 attributes(4)%type = "float"
145
146 end function get_attributes_batch
147 !###############################################################################
148
149
150 !###############################################################################
151 module subroutine build_from_onnx_base( &
152 this, node, initialisers, value_info, verbose &
153 )
154 !! Build layer from ONNX node and initialiser
155 implicit none
156
157 ! Arguments
158 class(base_layer_type), intent(inout) :: this
159 !! Instance of the layer
160 type(onnx_node_type), intent(in) :: node
161 !! ONNX node
162 type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
163 !! ONNX initialisers
164 type(onnx_tensor_type), dimension(:), intent(in) :: value_info
165 !! ONNX value info
166 integer, intent(in) :: verbose
167 !! Verbosity level
168
169 write(0,*) "build_from_onnx_base: " // &
170 trim(this%name) // " layer cannot be built from ONNX"
171
172 end subroutine build_from_onnx_base
173 !###############################################################################
174
175
176 !###############################################################################
177 module subroutine emit_onnx_nodes_base( &
178 this, prefix, &
179 nodes, num_nodes, max_nodes, &
180 inits, num_inits, max_inits, &
181 input_name, is_last_layer, format &
182 )
183 !! Default implementation: no-op (standard layers are handled by write_onnx)
184 implicit none
185
186 ! Arguments
187 class(base_layer_type), intent(in) :: this
188 !! Instance of the layer
189 character(*), intent(in) :: prefix
190 !! Prefix for node names
191 type(onnx_node_type), intent(inout), dimension(:) :: nodes
192 !! ONNX nodes
193 integer, intent(inout) :: num_nodes
194 !! Number of ONNX nodes
195 integer, intent(in) :: max_nodes
196 !! Maximum number of ONNX nodes
197 type(onnx_initialiser_type), intent(inout), dimension(:) :: inits
198 !! ONNX initialisers
199 integer, intent(inout) :: num_inits
200 !! Number of ONNX initialisers
201 integer, intent(in) :: max_inits
202 !! Maximum number of ONNX initialisers
203 character(*), optional, intent(in) :: input_name
204 !! Name of the input tensor from the previous layer
205 logical, optional, intent(in) :: is_last_layer
206 !! Whether this is the last non-input layer
207 integer, optional, intent(in) :: format
208 !! Export format selector
209
210 ! Default: do nothing. Standard layers are handled directly by write_onnx.
211 end subroutine emit_onnx_nodes_base
212 !###############################################################################
213
214
215 !###############################################################################
216 module subroutine emit_onnx_graph_inputs_base( &
217 this, prefix, &
218 graph_inputs, num_inputs &
219 )
220 !! Default implementation: no-op (standard layers don't add graph inputs)
221 implicit none
222
223 ! Arguments
224 class(base_layer_type), intent(in) :: this
225 !! Instance of the layer
226 character(*), intent(in) :: prefix
227 !! Prefix for input names
228 type(onnx_tensor_type), intent(inout), dimension(:) :: graph_inputs
229 !! ONNX graph inputs
230 integer, intent(inout) :: num_inputs
231 !! Number of ONNX graph inputs
232
233 ! Default: do nothing. Standard input layers are handled directly.
234 end subroutine emit_onnx_graph_inputs_base
235 !###############################################################################
236
237
238 !###############################################################################
239 module subroutine set_rank_base(this, input_rank, output_rank)
240 !! Set the input and output ranks of the layer
241 implicit none
242
243 ! Arguments
244 class(base_layer_type), intent(inout) :: this
245 !! Instance of the layer
246 integer, intent(in) :: input_rank
247 !! Input rank
248 integer, intent(in) :: output_rank
249 !! Output rank
250
251 !---------------------------------------------------------------------------
252 ! Set input and output ranks
253 !---------------------------------------------------------------------------
254 call stop_program("set_rank_base: this layer cannot have its rank set")
255
256 end subroutine set_rank_base
257 !###############################################################################
258
259
260 !###############################################################################
261
1/2
✓ Branch 0 taken 420 times.
✗ Branch 1 not taken.
420 module subroutine set_shape_base(this, input_shape)
262 !! Set the input shape of the layer
263 implicit none
264
265 ! Arguments
266 class(base_layer_type), intent(inout) :: this
267 !! Instance of the layer
268 integer, dimension(:), intent(in) :: input_shape
269 !! Input shape
270 character(len=100) :: err_msg
271 !! Error message
272
273 !---------------------------------------------------------------------------
274 ! initialise input shape
275 !---------------------------------------------------------------------------
276
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 420 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 420 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 420 times.
✓ Branch 9 taken 420 times.
✗ Branch 10 not taken.
420 if(size(input_shape,dim=1).eq.this%input_rank)then
277
7/14
✗ Branch 0 not taken.
✓ Branch 1 taken 420 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 420 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 420 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 420 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 420 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 847 times.
✓ Branch 16 taken 420 times.
1267 this%input_shape = input_shape
278 else
279 write(err_msg,'("Invalid size of input_shape in ",A,&
280 &" expected (",I0,"), got (",I0,")")') &
281 trim(this%name), this%input_rank, size(input_shape,dim=1)
282 call stop_program(err_msg)
283 return
284 end if
285
286 end subroutine set_shape_base
287 !###############################################################################
288
289
290 !###############################################################################
291 61 module subroutine extract_output_base(this, output)
292 !! Get the output of the layer
293 implicit none
294
295 ! Arguments
296 class(base_layer_type), intent(in) :: this
297 !! Instance of the layer
298 real(real32), allocatable, dimension(..), intent(out) :: output
299 !! Output of the Layer
300
301
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 122 times.
✓ Branch 2 taken 122 times.
✓ Branch 3 taken 61 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 61 times.
183 if(size(this%output).gt.1)then
302 call print_warning("extract_output_base: output has more than one"&
303 &" sample, cannot extract")
304 return
305 end if
306
307
5/10
✗ Branch 0 not taken.
✓ Branch 1 taken 61 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 61 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 61 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 61 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 61 times.
61 call this%output(1,1)%extract(output)
308
309 end subroutine extract_output_base
310 !###############################################################################
311
312
313 !###############################################################################
314 pure module function get_num_params_base(this) result(num_params)
315 !! Get the number of parameters in the layer
316 implicit none
317
318 ! Arguments
319 class(base_layer_type), intent(in) :: this
320 !! Instance of the layer
321 integer :: num_params
322 !! Number of parameters
323
324 ! No parameters in the base layer
325 num_params = 0
326
327 end function get_num_params_base
328 !-------------------------------------------------------------------------------
329 67 pure module function get_num_params_conv(this) result(num_params)
330 !! Get the number of parameters in convolutional layer
331 implicit none
332
333 ! Arguments
334 class(conv_layer_type), intent(in) :: this
335 !! Instance of the layer
336 integer :: num_params
337 !! Number of parameters
338
339 ! num_filters x num_channels x kernel_size + num_biases
340 ! num_biases = num_filters
341 268 num_params = this%num_filters * this%num_channels * product(this%knl) + &
342
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 67 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 67 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 67 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 67 times.
✓ Branch 12 taken 156 times.
✓ Branch 13 taken 67 times.
223 this%num_filters
343
344 67 end function get_num_params_conv
345 !-------------------------------------------------------------------------------
346 18 pure module function get_num_params_batch(this) result(num_params)
347 !! Get the number of parameters in batch normalisation layer
348 implicit none
349
350 ! Arguments
351 class(batch_layer_type), intent(in) :: this
352 !! Instance of the layer
353 integer :: num_params
354 !! Number of parameters
355
356 ! num_filters x num_channels x kernel_size + num_biases
357 ! num_biases = num_filters
358 18 num_params = 2 * this%num_channels
359
360 18 end function get_num_params_batch
361 !###############################################################################
362
363
364 !###############################################################################
365 module subroutine forward_base(this, input)
366 !! Forward pass for the layer
367 implicit none
368
369 ! Arguments
370 class(base_layer_type), intent(inout) :: this
371 !! Instance of the layer
372 class(array_type), dimension(:,:), intent(in) :: input
373 !! Input data
374
375 ! Local variables
376 integer :: i, j
377 !! Loop indices
378
379 do i = 1, size(input, 1)
380 do j = 1, size(input, 2)
381 if(.not.input(i,j)%allocated)then
382 call stop_program('Input to input layer not allocated')
383 return
384 end if
385 this%output(i,j) = input(i,j)
386 end do
387 end do
388
389 end subroutine forward_base
390 !-------------------------------------------------------------------------------
391 module function forward_eval_base(this, input) result(output)
392 !! Forward pass of layer and return output for evaluation
393 implicit none
394
395 ! Arguments
396 class(base_layer_type), intent(inout), target :: this
397 !! Instance of the layer
398 class(array_type), dimension(:,:), intent(in) :: input
399 !! Input data
400 type(array_type), pointer :: output(:,:)
401 !! Output data
402
403 call this%forward(input)
404 output => this%output
405 end function forward_eval_base
406 !###############################################################################
407
408
409 !###############################################################################
410 module subroutine set_graph_base(this, graph)
411 !! Set the graph structure of the input data
412 implicit none
413
414 ! Arguments
415 class(base_layer_type), intent(inout) :: this
416 !! Instance of the layer
417 type(graph_type), dimension(:), intent(in) :: graph
418 !! Graph structure of input data
419
420 ! Local variables
421 integer :: s
422 !! Loop index
423
424 if(allocated(this%graph))then
425 if(size(this%graph).ne.size(graph))then
426 deallocate(this%graph)
427 allocate(this%graph(size(graph)))
428 end if
429 else
430 allocate(this%graph(size(graph)))
431 end if
432 do s = 1, size(graph)
433 this%graph(s)%adj_ia = graph(s)%adj_ia
434 this%graph(s)%adj_ja = graph(s)%adj_ja
435 this%graph(s)%edge_weights = graph(s)%edge_weights
436 this%graph(s)%num_edges = graph(s)%num_edges
437 this%graph(s)%num_vertices = graph(s)%num_vertices
438 end do
439
440 end subroutine set_graph_base
441 !###############################################################################
442
443
444 !###############################################################################
445 module subroutine nullify_graph_base(this)
446 !! Nullify the forward pass data of the layer to free memory
447 implicit none
448
449 ! Arguments
450 class(base_layer_type), intent(inout) :: this
451 !! Instance of the layer
452
453 ! Local variables
454 integer :: i, j
455 !! Loop indices
456
457 do i = 1, size(this%output,1)
458 do j = 1, size(this%output,2)
459 call this%output(i,j)%nullify_graph()
460 end do
461 end do
462
463 end subroutine nullify_graph_base
464 !###############################################################################
465
466
467 !###############################################################################
468 11 module subroutine reduce_learnable(this, input)
469 !! Merge two learnable layers via summation
470 implicit none
471
472 ! Arguments
473 class(learnable_layer_type), intent(inout) :: this
474 !! Instance of the layer
475 class(learnable_layer_type), intent(in) :: input
476 !! Instance of a layer
477
478 ! Local variables
479 integer :: i
480 !! Loop index
481
482
1/2
✓ Branch 0 taken 11 times.
✗ Branch 1 not taken.
11 if(allocated(this%params).and.allocated(input%params))then
483
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
11 if(size(this%params).ne.size(input%params))then
484 call stop_program("reduce_learnable: incompatible parameter sizes")
485 return
486 end if
487
2/2
✓ Branch 0 taken 28 times.
✓ Branch 1 taken 11 times.
39 do i = 1, size(this%params,1)
488
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 28 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 28 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 28 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 28 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 28 times.
28 this%params(i) = this%params(i) + input%params(i)
489
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 28 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 28 times.
56 if(associated(this%params(i)%grad).and.&
490
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 28 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 28 times.
39 associated(input%params(i)%grad))then
491 this%params(i)%grad = this%params(i)%grad + &
492 input%params(i)%grad
493 end if
494 end do
495 else
496 call stop_program("reduce_learnable: unallocated parameter arrays")
497 return
498 end if
499
500 end subroutine reduce_learnable
501 !###############################################################################
502
503
504 !###############################################################################
505 12 module function add_learnable(a, b) result(output)
506 !! Add two learnable layers together
507 implicit none
508
509 ! Arguments
510 class(learnable_layer_type), intent(in) :: a, b
511 !! Instances of layers
512 class(learnable_layer_type), allocatable :: output
513 !! Output layer
514
515 ! Local variables
516 integer :: i
517 !! Loop index
518
519
3/10
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 12 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
12 output = a
520
1/2
✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
12 if(allocated(a%params).and.allocated(b%params))then
521
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
12 if(size(a%params).ne.size(b%params))then
522 call stop_program("add_learnable: incompatible parameter sizes")
523 return
524 end if
525
2/2
✓ Branch 0 taken 27 times.
✓ Branch 1 taken 12 times.
39 do i = 1, size(a%params,1)
526
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
27 output%params(i)%grad => null()
527
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 27 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 27 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 27 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 27 times.
27 output%params(i) = a%params(i) + b%params(i)
528
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
✓ Branch 6 taken 3 times.
✓ Branch 7 taken 24 times.
54 if(associated(a%params(i)%grad).and.&
529
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 27 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 27 times.
39 associated(b%params(i)%grad))then
530
5/10
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
3 allocate(output%params(i)%grad)
531 12 output%params(i)%grad = a%params(i)%grad + &
532
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
3 b%params(i)%grad
533 end if
534 end do
535 else
536 call stop_program("add_learnable: unallocated parameter arrays")
537 return
538 end if
539
540 12 end function add_learnable
541 !###############################################################################
542
543
544 !###############################################################################
545
1/2
✓ Branch 0 taken 17 times.
✗ Branch 1 not taken.
17 pure module function get_params(this) result(params)
546 !! Get the learnable parameters of the layer
547 !!
548 !! This function returns the learnable parameters of the layer
549 !! as a single array.
550 !! This has been modified from the neural-fortran library
551 implicit none
552
553 ! Arguments
554 class(learnable_layer_type), intent(in) :: this
555 !! Instance of the layer
556 real(real32), dimension(this%num_params) :: params
557 !! Learnable parameters
558
559 ! Local variables
560 integer :: i, start_idx, end_idx
561 !! Loop indices
562
563 17 start_idx = 0
564 17 end_idx = 0
565
2/2
✓ Branch 0 taken 41 times.
✓ Branch 1 taken 17 times.
58 do i = 1, size(this%params)
566 41 start_idx = end_idx + 1
567
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 41 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 41 times.
41 end_idx = start_idx + size(this%params(i)%val,1) - 1
568
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 41 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 41 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 41 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 41 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 41 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 41 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 41 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 41 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 41 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 41 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 41 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 41 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 41 times.
✓ Branch 39 taken 4880 times.
✓ Branch 40 taken 41 times.
4938 params(start_idx:end_idx) = this%params(i)%val(:,1)
569 end do
570
571
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 17 times.
17 end function get_params
572 !###############################################################################
573
574
575 !###############################################################################
576 11 module subroutine set_params(this, params)
577 !! Set the learnable parameters of the layer
578 !!
579 !! This function sets the learnable parameters of the layer
580 !! from a single array.
581 !! This has been modified from the neural-fortran library
582 implicit none
583
584 ! Arguments
585 class(learnable_layer_type), intent(inout) :: this
586 !! Instance of the layer
587 real(real32), dimension(this%num_params), intent(in) :: params
588 !! Learnable parameters
589
590 ! Local variables
591 integer :: i, start_idx, end_idx
592 !! Loop indices
593
594
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
11 if(.not.allocated(this%params))then
595 call stop_program("set_params: params not allocated")
596 return
597 end if
598 11 start_idx = 0
599 11 end_idx = 0
600
2/2
✓ Branch 0 taken 25 times.
✓ Branch 1 taken 11 times.
36 do i = 1, size(this%params)
601 25 start_idx = end_idx + 1
602
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
25 end_idx = start_idx + size(this%params(i)%val,1) - 1
603
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 25 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 25 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 25 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 25 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 25 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 25 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 25 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 25 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 25 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 25 times.
✓ Branch 39 taken 2556 times.
✓ Branch 40 taken 25 times.
2581 this%params(i)%val(:,1) = params(start_idx:end_idx)
604 end do
605
606 end subroutine set_params
607 !###############################################################################
608
609
610 !###############################################################################
611
1/2
✓ Branch 0 taken 35 times.
✗ Branch 1 not taken.
35 pure module function get_gradients(this, clip_method) result(gradients)
612 !! Get the gradients of the layer
613 !!
614 !! This function returns the gradients of the layer as a single array.
615 !! This has been modified from the neural-fortran library
616 use athena__clipper, only: clip_type
617 implicit none
618
619 ! Arguments
620 class(learnable_layer_type), intent(in) :: this
621 !! Instance of the layer
622 type(clip_type), optional, intent(in) :: clip_method
623 !! Method to clip the gradients
624 real(real32), dimension(this%num_params) :: gradients
625 !! Gradients of the layer
626
627 ! Local variables
628 integer :: i, start_idx, end_idx
629 !! Loop indices
630
631
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 35 times.
35 if(.not.allocated(this%params))then
632 return
633 end if
634 35 start_idx = 0
635 35 end_idx = 0
636
2/2
✓ Branch 0 taken 66 times.
✓ Branch 1 taken 35 times.
101 do i = 1, size(this%params)
637 66 start_idx = end_idx + 1
638
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
66 end_idx = start_idx + size(this%params(i)%val,1) - 1
639
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 66 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 66 times.
✓ Branch 6 taken 16 times.
✓ Branch 7 taken 50 times.
101 if(.not.associated(this%params(i)%grad))then
640
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✓ Branch 12 taken 2324 times.
✓ Branch 13 taken 16 times.
2340 gradients(start_idx:end_idx) = 0._real32
641 else
642
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 50 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 50 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 50 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 50 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 50 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 50 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 50 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 50 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 50 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 50 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 50 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 50 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 50 times.
✓ Branch 39 taken 7338 times.
✓ Branch 40 taken 50 times.
7388 gradients(start_idx:end_idx) = this%params(i)%grad%val(:,1)
643 end if
644 end do
645
646
1/8
✗ Branch 0 not taken.
✓ Branch 1 taken 35 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
35 if(present(clip_method)) call clip_method%apply(size(gradients),gradients)
647
648
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 35 times.
35 end function get_gradients
649 !###############################################################################
650
651
652 !###############################################################################
653 23 module subroutine set_gradients(this, gradients)
654 !! Set the gradients of the layer
655 !!
656 !! This function sets the gradients of the layer from a single array.
657 !! This has been modified from the neural-fortran library
658 implicit none
659
660 ! Arguments
661 class(learnable_layer_type), intent(inout) :: this
662 !! Instance of the layer
663 real(real32), dimension(..), intent(in) :: gradients
664 !! Gradients of the layer
665
666 ! Local variables
667 integer :: i, start_idx, end_idx
668 !! Loop indices
669
670 23 start_idx = 0
671 23 end_idx = 0
672 select rank(gradients)
673 rank(0)
674
2/2
✓ Branch 0 taken 25 times.
✓ Branch 1 taken 15 times.
40 do i = 1, size(this%params)
675
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
✓ Branch 6 taken 9 times.
✓ Branch 7 taken 16 times.
25 if(.not.associated(this%params(i)%grad))then
676
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
9 this%params(i)%grad => this%params(i)%create_result()
677 end if
678
10/18
✗ Branch 0 not taken.
✓ Branch 1 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 25 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 25 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 25 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 25 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 25 times.
✓ Branch 24 taken 2368 times.
✓ Branch 25 taken 25 times.
2408 this%params(i)%grad%val(:,1) = gradients
679 end do
680 rank(1)
681
2/2
✓ Branch 0 taken 16 times.
✓ Branch 1 taken 8 times.
24 do i = 1, size(this%params)
682
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✓ Branch 6 taken 13 times.
✓ Branch 7 taken 3 times.
16 if(.not.associated(this%params(i)%grad))then
683
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
13 this%params(i)%grad => this%params(i)%create_result()
684 end if
685 16 start_idx = end_idx + 1
686
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
16 end_idx = start_idx + size(this%params(i)%val,1) - 1
687
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 16 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 16 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 16 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 16 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 16 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 16 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 16 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 16 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 16 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 16 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 16 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 16 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 16 times.
✓ Branch 39 taken 2250 times.
✓ Branch 40 taken 16 times.
2274 this%params(i)%grad%val(:,1) = gradients(start_idx:end_idx)
688 end do
689 end select
690
691 23 end subroutine set_gradients
692 !###############################################################################
693
694
6/8
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 15 times.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
✓ Branch 5 taken 8 times.
✓ Branch 6 taken 8 times.
✗ Branch 7 not taken.
23 end submodule athena__base_layer_submodule
695