GCC Code Coverage Report


Directory: src/athena/
File: athena_concat_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__concat_layer
2 !! Module containing implementation of a concatenation layer
3 !!
4 !! This module implements a merge layer that concatenates multiple input
5 !! tensors along a specified dimension (features for 2D, channels for 3D).
6 !!
7 !! Mathematical operation:
8 !! output = [input_1 || input_2 || ... || input_N]
9 !!
10 !! where || denotes concatenation along the appropriate dimension.
11 !! Output size along concatenation dimension = sum of input sizes.
12 !! Gradients are split to corresponding input portions during backpropagation.
13 use coreutils, only: real32, stop_program
14 use athena__base_layer, only: merge_layer_type, base_layer_type
15 use diffstruc, only: array_type, operator(+)
16 use athena__diffstruc_extd, only: array_ptr_type, concat_layers
17 implicit none
18
19
private
public :: concat_layer_type, read_concat_layer


type, extends(merge_layer_type) :: concat_layer_type
  !! Merge layer that joins its inputs by concatenation
  integer, dimension(:,:), allocatable :: io_map
  !! I/O mapping for the layer
contains
  ! --- configuration / lifecycle ---
  procedure, pass(this) :: set_hyperparams => set_hyperparams_concat
  !! Store name, merge mode, input layer IDs and ranks
  procedure, pass(this) :: init => init_concat
  !! Set input/output shapes and reset the output storage
  ! --- serialisation ---
  procedure, pass(this) :: print_to_unit => print_to_unit_concat
  !! Write the layer parameters to a formatted unit
  procedure, pass(this) :: read => read_concat
  !! Read the layer parameters from a formatted unit
  ! --- graph construction / forward pass ---
  procedure, pass(this) :: calc_input_shape => calc_input_shape_concat
  !! Derive this layer's input shape from the shapes of its input layers
  procedure, pass(this) :: combine => combine_concat
  !! Concatenate the input arrays into the layer output
end type concat_layer_type


interface concat_layer_type
  !! Generic constructor: concat_layer_type(...) resolves to layer_setup
  module function layer_setup(input_layer_ids, input_rank, verbose) result(layer)
    !! Set up a concatenate layer (input_rank must be supplied at run time)
    integer, dimension(:), intent(in) :: input_layer_ids
    !! Input layer IDs
    integer, optional, intent(in) :: input_rank
    !! Input rank
    integer, optional, intent(in) :: verbose
    !! Verbosity level
    type(concat_layer_type) :: layer
    !! Constructed layer
  end function layer_setup
end interface concat_layer_type
61
62
63
64 contains
65
66 !###############################################################################
module function layer_setup( &
     input_layer_ids, input_rank, verbose &
) result(layer)
  !! Set up a concatenate layer.
  !!
  !! Backs the generic constructor `concat_layer_type(...)`. The input rank is
  !! mandatory: when it is absent, stop_program is called and the function
  !! returns without configuring the layer.
  implicit none

  ! Arguments
  integer, dimension(:), intent(in) :: input_layer_ids
  !! IDs of the layers whose outputs are concatenated
  integer, optional, intent(in) :: input_rank
  !! Input rank (required)
  integer, optional, intent(in) :: verbose
  !! Verbosity level (defaults to 0)

  type(concat_layer_type) :: layer
  !! Instance of the concatenate layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assign defaults at run time. Initialising in the declaration would imply
  ! SAVE, so a verbosity passed in one call would persist into later calls
  ! that omit `verbose`.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  if(.not.present(input_rank))then
     call stop_program( &
          "input_rank or input_shape must be provided to concat layer" &
     )
     return
  end if

  call layer%set_hyperparams( &
       input_layer_ids = input_layer_ids, &
       input_rank = input_rank, &
       verbose = verbose_ &
  )

end function layer_setup
111 !###############################################################################
112
113
114 !###############################################################################
subroutine set_hyperparams_concat(this, input_layer_ids, input_rank, verbose)
  !! Configure the identity and rank metadata of a concatenate layer.
  !!
  !! Records the layer name/type, selects merge mode 2 (concatenate), stores
  !! the IDs of the input layers and sets both the input and output rank to
  !! `input_rank`. The `verbose` argument is accepted but not used here.
  implicit none

  ! Arguments
  class(concat_layer_type), intent(inout) :: this
  !! Layer being configured
  integer, dimension(:), intent(in) :: input_layer_ids
  !! IDs of the layers feeding this one
  integer, intent(in) :: input_rank
  !! Rank of the inputs (also used as the output rank)
  integer, optional, intent(in) :: verbose
  !! Verbosity level (unused)

  ! Identification
  this%name       = "concatenate"
  this%type       = "merg"
  this%merge_mode = 2          ! 2 selects concatenation

  ! Connectivity and ranks
  this%input_layer_ids = input_layer_ids
  this%input_rank      = input_rank
  this%output_rank     = this%input_rank

end subroutine set_hyperparams_concat
143 !###############################################################################
144
145
146 !###############################################################################
subroutine init_concat(this, input_shape, verbose)
  !! Initialise the shapes and output storage of a concatenate layer.
  !!
  !! The input rank is always taken from `input_shape`, but the stored input
  !! shape is only set if it has not been set before. The output shape and
  !! rank are copied from the stored input shape, and `this%output` is reset
  !! to a 1x1 placeholder.
  implicit none

  ! Arguments
  class(concat_layer_type), intent(inout) :: this
  !! Instance of the concatenate layer
  integer, dimension(:), intent(in) :: input_shape
  !! Input shape
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: verbose_
  !! Verbosity level


  !---------------------------------------------------------------------------
  ! Initialise optional arguments
  !---------------------------------------------------------------------------
  ! Assigned at run time: a declaration initialiser would imply SAVE and let
  ! a previous call's verbosity persist.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Initialise input shape
  !---------------------------------------------------------------------------
  this%input_rank = size(input_shape)
  if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)


  !---------------------------------------------------------------------------
  ! Initialise output shape (identical to the stored input shape)
  !---------------------------------------------------------------------------
  this%output_shape = this%input_shape
  this%output_rank = size(this%output_shape)


  !---------------------------------------------------------------------------
  ! Reset output storage to a 1x1 placeholder
  !---------------------------------------------------------------------------
  if(allocated(this%output)) deallocate(this%output)
  allocate(this%output(1,1))

end subroutine init_concat
191 !###############################################################################
192
193
194 !##############################################################################!
195 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
196 !##############################################################################!
197
198
199 !###############################################################################
subroutine print_to_unit_concat(this, unit)
  !! Write the parameters of a concatenate layer to a formatted unit.
  !!
  !! Emits one tag line each for INPUT_RANK, INPUT_SHAPE (only when the input
  !! shape has been set) and INPUT_LAYER_IDS, in the "TAG = values" form that
  !! read_concat parses.
  implicit none

  ! Arguments
  class(concat_layer_type), intent(in) :: this
  !! Instance of the concatenate layer
  integer, intent(in) :: unit
  !! File unit


  ! Write initial parameters
  !---------------------------------------------------------------------------
  ! The unlimited repeat `*(...)` (Fortran 2008) replaces a format string
  ! built at run time. It produces the same text, and it also stays valid for
  ! zero-length arrays, where a `0(...)` repeat count is not permitted.
  write(unit,'(3X,"INPUT_RANK = ",I0)') this%input_rank
  if(allocated(this%input_shape)) &
       write(unit,'(3X,"INPUT_SHAPE =",*(1X,I0))') this%input_shape
  write(unit,'(3X,"INPUT_LAYER_IDS =",*(1X,I0))') this%input_layer_ids

end subroutine print_to_unit_concat
225 !###############################################################################
226
227
228 !###############################################################################
subroutine read_concat(this, unit, verbose)
  !! Read a concatenate layer card from a formatted unit.
  !!
  !! Reads tag lines (INPUT_SHAPE, INPUT_RANK, INPUT_LAYER_IDS) until the line
  !! "END <layer name in upper case>" is found. It then configures the layer
  !! from the tags and finally consumes that END line.
  !!
  !! The following lines are skipped:
  !!   - blank lines;
  !!   - lines containing no letters other than 'e' (numeric data);
  !!   - lines whose tag starts with END.
  !!
  !! Any other unrecognised line is reported through stop_program.
  use athena__tools_infile, only: assign_val, assign_vec, get_val
  use coreutils, only: to_lower, to_upper, icount
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  ! Arguments
  class(concat_layer_type), intent(inout) :: this
  !! Instance of the concatenate layer
  integer, intent(in) :: unit
  !! Unit number
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: stat, verbose_
  !! File status and verbosity level
  integer :: itmp1
  !! Temporary integer
  integer :: input_rank
  !! Input rank (0 means "not given")
  integer, dimension(:), allocatable :: input_shape, input_layer_ids
  !! Input shape and input layer IDs
  character(256) :: buffer, tag, err_msg
  !! Buffer, tag, and error message


  ! Initialise locals and optional arguments
  !---------------------------------------------------------------------------
  ! Assigned at run time rather than in the declarations: a declaration
  ! initialiser implies SAVE, so an INPUT_RANK read by a previous call would
  ! otherwise leak into the next read.
  verbose_ = 0
  itmp1 = 0
  input_rank = 0
  if(present(verbose)) verbose_ = verbose


  ! Loop over tags in layer card
  !---------------------------------------------------------------------------
  tag_loop: do

     ! Check for end of file
     !------------------------------------------------------------------------
     read(unit,'(A)',iostat=stat) buffer
     if(stat.ne.0)then
        write(err_msg,'("file encountered error (EoF?) before END ",A)') &
             to_upper(this%name)
        call stop_program(err_msg)
        return
     end if
     if(trim(adjustl(buffer)).eq."") cycle tag_loop

     ! Check for end of layer card (left unread for the final check below)
     !------------------------------------------------------------------------
     if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
        backspace(unit)
        exit tag_loop
     end if

     tag=trim(adjustl(buffer))
     if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

     ! Read parameters from file (a repeated tag replaces the earlier value)
     !------------------------------------------------------------------------
     select case(trim(tag))
     case("INPUT_SHAPE")
        itmp1 = icount(get_val(buffer))
        if(allocated(input_shape)) deallocate(input_shape)
        allocate(input_shape(itmp1), source=0)
        call assign_vec(buffer, input_shape, itmp1)
     case("INPUT_RANK")
        call assign_val(buffer, input_rank, itmp1)
     case("INPUT_LAYER_IDS")
        itmp1 = icount(get_val(buffer))
        if(allocated(input_layer_ids)) deallocate(input_layer_ids)
        allocate(input_layer_ids(itmp1), source=0)
        call assign_vec(buffer, input_layer_ids, itmp1)
     case default
        ! Don't look for "e" due to scientific notation of numbers
        ! ... i.e. exponent (E+00)
        if(scan(to_lower(trim(adjustl(buffer))),&
             'abcdfghijklmnopqrstuvwxyz').eq.0)then
           cycle tag_loop
        elseif(tag(:3).eq.'END')then
           cycle tag_loop
        end if
        write(err_msg,'("Unrecognised line in input file: ",A)') &
             trim(adjustl(buffer))
        call stop_program(err_msg)
        return
     end select
  end do tag_loop


  ! Validate what was read
  !---------------------------------------------------------------------------
  ! set_hyperparams takes input_layer_ids as a non-allocatable dummy, so an
  ! unallocated array must not be passed on.
  if(.not.allocated(input_layer_ids))then
     write(err_msg,'("INPUT_LAYER_IDS must be provided for ",A)') &
          to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

  if(allocated(input_shape))then
     if(input_rank.eq.0)then
        input_rank = size(input_shape)
     elseif(input_rank.ne.size(input_shape))then
        write(err_msg,'("input_rank (",I0,") does not match input_shape (",I0,")")') &
             input_rank, size(input_shape)
        call stop_program(err_msg)
        return
     end if
  elseif(input_rank.eq.0)then
     write(err_msg,'("input_rank must be provided if input_shape is not")')
     call stop_program(err_msg)
     return
  end if


  ! Set hyperparameters and initialise layer
  !---------------------------------------------------------------------------
  call this%set_hyperparams( &
       input_layer_ids = input_layer_ids, &
       input_rank = input_rank, &
       verbose = verbose_ &
  )
  ! init needs an explicit shape. When only INPUT_RANK was given, the layer
  ! is left uninitialised — presumably to be initialised later once the
  ! input layers' shapes are known (NOTE(review): confirm with caller).
  if(allocated(input_shape)) call this%init(input_shape = input_shape)


  ! Check for end of layer card
  !---------------------------------------------------------------------------
  read(unit,'(A)',iostat=stat) buffer
  if(stat.ne.0) buffer = ""
  if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
     write(error_unit,*) trim(adjustl(buffer))
     write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
     call stop_program(err_msg)
     return
  end if

end subroutine read_concat
352 !###############################################################################
353
354
355 !###############################################################################
function read_concat_layer(unit, verbose) result(layer)
  !! Read a concatenate layer card from a unit and return it as a
  !! polymorphic base layer.
  implicit none

  ! Arguments
  integer, intent(in) :: unit
  !! Unit number
  integer, optional, intent(in) :: verbose
  !! Verbosity level
  class(base_layer_type), allocatable :: layer
  !! Instance of the concatenate layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assigned at run time to avoid the implicit SAVE of a declaration
  ! initialiser.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose

  ! Build a placeholder instance so that read() dispatches to read_concat.
  ! The constructor (layer_setup) calls stop_program when input_rank is
  ! absent, so a dummy rank is supplied. read_concat then replaces the layer
  ! IDs and rank through set_hyperparams.
  allocate(layer, source=concat_layer_type(input_layer_ids=[0,0], input_rank=1))
  call layer%read(unit, verbose=verbose_)

end function read_concat_layer
377 !###############################################################################
378
379
380 !##############################################################################!
381 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
382 !##############################################################################!
383
384
385 !###############################################################################
function calc_input_shape_concat(this, input_shapes) result(input_shape)
  !! Combine the shapes of the input layers into this layer's input shape.
  !!
  !! Column j of `input_shapes` is the shape of input j; rows index the
  !! dimensions. Concatenation is along the last dimension (features for
  !! rank-1 shapes, channels for higher ranks — assumes channels are stored
  !! last; NOTE(review): confirm). The last extent is therefore the sum over
  !! all inputs, while every other extent must agree across inputs and is
  !! copied through. Summing every row, as a plain sum(input_shapes, dim=2)
  !! would, wrongly adds up non-concatenated (e.g. spatial) extents.
  !!
  !! With no inputs (or zero rank), an all-zero shape of the given rank is
  !! returned.
  implicit none

  ! Arguments
  class(concat_layer_type), intent(in) :: this
  !! Instance of the layer
  integer, dimension(:,:), intent(in) :: input_shapes
  !! Input shapes, one column per input layer
  integer, allocatable, dimension(:) :: input_shape
  !! Calculated input shape

  ! Local variables
  integer :: rank, n_inputs, j
  !! Shape rank, number of inputs, input index
  character(256) :: err_msg
  !! Error message


  rank = size(input_shapes, 1)
  n_inputs = size(input_shapes, 2)

  allocate(input_shape(rank), source=0)
  if(rank.eq.0 .or. n_inputs.eq.0) return

  ! Non-concatenated extents come from the first input; the concatenated
  ! (last) extent accumulates over all inputs.
  input_shape = input_shapes(:,1)
  input_shape(rank) = sum(input_shapes(rank,:))

  ! Every input must agree with the first outside the concatenated dimension
  do j = 2, n_inputs
     if(any(input_shapes(1:rank-1,j).ne.input_shapes(1:rank-1,1)))then
        write(err_msg,'("concat layer: shape of input ",I0," differs from input 1 outside the last dimension")') j
        call stop_program(err_msg)
        return
     end if
  end do

end function calc_input_shape_concat
402 !###############################################################################
403
404
405 !##############################################################################!
406 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
407 !##############################################################################!
408
409
410 !###############################################################################
subroutine combine_concat(this, input_list)
  !! Concatenate the arrays in `input_list` into `this%output`.
  !!
  !! `this%output` is (re)allocated to the shape of the first input's array
  !! grid when it is absent or shaped differently. For every grid entry
  !! (row, col), the entries of all inputs are concatenated via
  !! concat_layers(..., dim = 1) and moved into the output entry. Entries
  !! where any input is not yet allocated are skipped and keep their previous
  !! contents.
  implicit none

  ! Arguments
  class(concat_layer_type), intent(inout) :: this
  !! Instance of the concatenate layer
  type(array_ptr_type), dimension(:), intent(in) :: input_list
  !! Input values

  ! Local variables
  integer :: row, col, k, n_rows, n_cols
  !! Grid indices, input index, and grid extents
  logical :: inputs_ready
  !! True when every input has an allocated entry at (row, col)
  type(array_type), pointer :: joined
  !! Concatenated result for one grid entry


  n_rows = size(input_list(1)%array, 1)
  n_cols = size(input_list(1)%array, 2)

  ! Drop an output grid whose shape no longer matches, then allocate if needed
  if(allocated(this%output))then
     if(size(this%output, 1).ne.n_rows .or. &
          size(this%output, 2).ne.n_cols) deallocate(this%output)
  end if
  if(.not.allocated(this%output)) allocate(this%output(n_rows, n_cols))

  ! Column-major traversal: the first index varies fastest
  do col = 1, n_cols
     do row = 1, n_rows
        inputs_ready = .true.
        do k = 1, size(input_list)
           if(.not.input_list(k)%array(row, col)%allocated)then
              inputs_ready = .false.
              exit
           end if
        end do
        if(.not.inputs_ready) cycle

        joined => concat_layers(input_list, row, col, dim = 1)
        associate(out_entry => this%output(row, col))
          call out_entry%zero_grad()
          call out_entry%assign_and_deallocate_source(joined)
          out_entry%is_temporary = .false.
        end associate
     end do
  end do

end subroutine combine_concat
456 !###############################################################################
457
458 end module athena__concat_layer
459