GCC Code Coverage Report


Directory: src/athena/
File: athena_reshape_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__reshape_layer
2 !! Module containing implementation of a reshape layer
3 !!
4 !! This module implements a general reshape layer that can transform tensors
5 !! between arbitrary shapes while preserving the total number of elements.
6 !! Unlike flatten (which only converts to 1D), reshape allows any target shape.
7 !!
8 !! Mathematical operation:
9 !! Reshape: (d1, d2, ..., dn) -> (d1', d2', ..., dm')
10 !! where: d1 * d2 * ... * dn = d1' * d2' * ... * dm'
11 !!
12 !! Examples:
13 !! - (28, 28) -> (784) [flatten]
14 !! - (784) -> (28, 28) [unflatten]
15 !! - (64, 32, 32) -> (64, 1024) [spatial to sequence]
16 !! - (100, 50) -> (10, 10, 50) [add spatial dimension]
17 !!
18 !! Properties:
19 !! - No learnable parameters (pure reshape operation)
20 !! - Preserves all information (bijective mapping)
21 !! - No computation beyond memory reorganization
22 !! - Gradients flow unchanged (chain rule applies directly)
23 use coreutils, only: real32, stop_program
24 use athena__base_layer, only: base_layer_type
25 use diffstruc, only: array_type, reshape
26 use athena__misc_types, only: &
27 onnx_node_type, onnx_initialiser_type, onnx_tensor_type
28 implicit none
29
30
31 private
32
33 public :: reshape_layer_type
34 public :: read_reshape_layer, create_from_onnx_reshape_layer
35
36
type, extends(base_layer_type) :: reshape_layer_type
  !! Reshape layer: presents its input under a different shape holding the
  !! same number of elements. It has no learnable parameters; every binding
  !! below overrides the corresponding base-layer procedure.
contains
  ! --- setup ----------------------------------------------------------------
  procedure, pass(this) :: set_hyperparams => set_hyperparams_reshape
  !! Store the target output shape and layer metadata
  procedure, pass(this) :: init => init_reshape
  !! Record the input shape and validate the element count

  ! --- propagation ----------------------------------------------------------
  procedure, pass(this) :: forward => forward_reshape
  !! Forward pass: reshape the input into the layer output

  ! --- serialisation --------------------------------------------------------
  procedure, pass(this) :: print_to_unit => print_to_unit_reshape
  !! Write layer parameters to a file unit
  procedure, pass(this) :: read => read_reshape
  !! Read layer parameters from a file unit
  procedure, pass(this) :: build_from_onnx => build_from_onnx_reshape
  !! Configure the layer from an ONNX node, initialisers and value infos
end type reshape_layer_type
55
interface reshape_layer_type
  !! Generic constructor: `reshape_layer_type(output_shape, ...)` returns a
  !! configured reshape layer
  module function layer_setup( &
       output_shape, input_shape, verbose &
  ) result(layer)
    !! Set up the reshape layer
    type(reshape_layer_type) :: layer
    !! Constructed reshape layer
    integer, dimension(:), intent(in) :: output_shape
    !! Target output shape (excluding batch dimension)
    integer, dimension(:), optional, intent(in) :: input_shape
    !! Input shape (excluding batch dimension); if supplied the layer is
    !! also initialised
    integer, optional, intent(in) :: verbose
    !! Verbosity level
  end function layer_setup
end interface reshape_layer_type
72
73
74
75 contains
76
77 !###############################################################################
module function layer_setup( &
     output_shape, input_shape, verbose &
) result(layer)
  !! Set up the reshape layer
  !!
  !! Stores the target output shape and, when an input shape is supplied,
  !! initialises the layer (which checks that the element counts match).
  implicit none

  ! Arguments
  integer, dimension(:), intent(in) :: output_shape
  !! Target output shape (excluding batch dimension)
  integer, dimension(:), optional, intent(in) :: input_shape
  !! Input shape (excluding batch dimension)
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  type(reshape_layer_type) :: layer
  !! Instance of the reshape layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assign rather than initialise at declaration: an initialiser implies SAVE,
  ! which would carry a previous call's verbosity into later calls
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  call layer%set_hyperparams(output_shape, verbose_)


  !---------------------------------------------------------------------------
  ! Initialise layer
  !---------------------------------------------------------------------------
  if(present(input_shape)) call layer%init(input_shape, verbose_)

end function layer_setup
114 !###############################################################################
115
116
117 !###############################################################################
subroutine set_hyperparams_reshape(this, output_shape, verbose)
  !! Set hyperparameters for reshape layer
  !!
  !! Sets the layer type/name, stores the target output shape and its rank,
  !! and resets the input rank to 0 (it is set properly by init).
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(inout) :: this
  !! Instance of the reshape layer
  integer, dimension(:), intent(in) :: output_shape
  !! Target output shape (excluding batch dimension)
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assign rather than initialise at declaration (an initialiser implies SAVE)
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose

  this%type = "rshp"
  this%name = "reshape"
  this%input_rank = 0
  this%output_shape = output_shape
  this%output_rank = size(output_shape)

  if(verbose_ .gt. 0) write(*,'(" Setting up reshape layer")')

end subroutine set_hyperparams_reshape
145 !###############################################################################
146
147
148 !###############################################################################
subroutine init_reshape(this, input_shape, verbose)
  !! Initialise reshape layer
  !!
  !! Records the input shape and rank, then checks that the input holds the
  !! same number of elements as the stored output shape. Stops the program on
  !! a mismatch.
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(inout) :: this
  !! Instance of the reshape layer
  integer, dimension(:), intent(in) :: input_shape
  !! Input shape
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: verbose_
  !! Verbosity level
  integer :: input_size, output_size
  !! Total number of elements

  ! Assign rather than initialise at declaration (an initialiser implies SAVE)
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose

  !---------------------------------------------------------------------------
  ! Set input shape
  !---------------------------------------------------------------------------
  this%input_rank = size(input_shape)
  if(allocated(this%input_shape)) deallocate(this%input_shape)
  allocate(this%input_shape, source=input_shape)


  !---------------------------------------------------------------------------
  ! Validate reshape compatibility
  !---------------------------------------------------------------------------
  input_size = product(input_shape)
  output_size = product(this%output_shape)

  if(input_size .ne. output_size)then
    write(*,'("ERROR: Reshape layer - incompatible shapes")')
    write(*,'(" Input shape has ",I0," elements")') input_size
    write(*,'(" Output shape has ",I0," elements")') output_size
    call stop_program("Reshape layer shape mismatch")
  end if


  !---------------------------------------------------------------------------
  ! Print layer info
  !---------------------------------------------------------------------------
  if(verbose_ .gt. 0)then
    write(*,'(" Reshape layer initialised")')
    ! The ':' stops format processing after the last value, so no trailing
    ! " x " separator is printed
    write(*,'(" Input shape: ",*(I0,:," x "))') this%input_shape
    write(*,'(" Output shape: ",*(I0,:," x "))') this%output_shape
  end if

end subroutine init_reshape
204 !###############################################################################
205
206
207 !##############################################################################!
208 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
209 !##############################################################################!
210
211
212 !###############################################################################
subroutine print_to_unit_reshape(this, unit)
  !! Print reshape layer to unit
  !!
  !! Writes INPUT_RANK, INPUT_SHAPE (only if the layer has been initialised)
  !! and OUTPUT_SHAPE tag lines.
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(in) :: this
  !! Instance of the reshape layer
  integer, intent(in) :: unit
  !! File unit


  ! Write initial parameters
  !---------------------------------------------------------------------------
  write(unit,'(3X,"INPUT_RANK = ",I0)') this%input_rank
  ! input_shape is only allocated once init has run; querying its size
  ! beforehand is invalid
  if(allocated(this%input_shape))then
    ! Unlimited repeat '*(...)' also copes with zero-length shapes, for which
    ! an explicit repeat count of 0 would be an invalid format
    write(unit,'(3X,"INPUT_SHAPE =",*(1X,I0))') this%input_shape
  end if
  write(unit,'(3X,"OUTPUT_SHAPE =",*(1X,I0))') this%output_shape

end subroutine print_to_unit_reshape
237 !###############################################################################
238
239
240 !###############################################################################
subroutine read_reshape(this, unit, verbose)
  !! Read reshape layer from file
  !!
  !! Parses the tag lines of a reshape layer card (INPUT_RANK, INPUT_SHAPE,
  !! OUTPUT_SHAPE) up to the matching "END <NAME>" line. It then sets the
  !! hyperparameters and, if INPUT_SHAPE was given, initialises the layer.
  !! Stops the program on malformed input.
  use athena__tools_infile, only: assign_val, assign_vec, get_val
  use coreutils, only: to_lower, to_upper, icount
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(inout) :: this
  !! Instance of the reshape layer
  integer, intent(in) :: unit
  !! File unit
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: stat, verbose_
  !! File status and verbosity level
  integer :: itmp1
  !! Temporary integer
  integer :: input_rank
  !! Input rank read from INPUT_RANK (0 if absent)
  integer, dimension(:), allocatable :: input_shape, output_shape
  !! Input and output shapes read from the card
  character(256) :: buffer, tag, err_msg
  !! Buffer, tag, and error message


  ! Initialise locals (by assignment: initialisers would imply SAVE)
  !---------------------------------------------------------------------------
  verbose_ = 0
  itmp1 = 0
  input_rank = 0
  if(present(verbose)) verbose_ = verbose


  ! Loop over tags in layer card
  !---------------------------------------------------------------------------
  tag_loop: do

    ! Check for end of file
    !------------------------------------------------------------------------
    read(unit,'(A)',iostat=stat) buffer
    if(stat.ne.0)then
      write(err_msg,'("file encountered error (EoF?) before END ",A)') &
           to_upper(this%name)
      call stop_program(err_msg)
      return
    end if
    if(trim(adjustl(buffer)).eq."") cycle tag_loop

    ! Check for end of layer card
    !------------------------------------------------------------------------
    if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
      backspace(unit)
      exit tag_loop
    end if

    tag=trim(adjustl(buffer))
    if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))

    ! Read parameters from save file
    !------------------------------------------------------------------------
    select case(trim(tag))
    case("INPUT_RANK")
      ! Emitted by print_to_unit_reshape, so it must be accepted on read-back
      call assign_val(buffer, input_rank, itmp1)
    case("INPUT_SHAPE")
      itmp1 = icount(get_val(buffer))
      if(allocated(input_shape)) deallocate(input_shape)
      allocate(input_shape(itmp1), source=0)
      call assign_vec(buffer, input_shape, itmp1)
    case("OUTPUT_SHAPE")
      itmp1 = icount(get_val(buffer))
      if(allocated(output_shape)) deallocate(output_shape)
      allocate(output_shape(itmp1), source=0)
      call assign_vec(buffer, output_shape, itmp1)
    case default
      ! Don't look for "e" due to scientific notation of numbers
      ! ... i.e. exponent (E+00)
      if( &
           scan( &
           to_lower(trim(adjustl(buffer))), &
           'abcdfghijklmnopqrstuvwxyz' &
           ) .eq. 0 &
      )then
        cycle tag_loop
      elseif(tag(:3).eq.'END')then
        cycle tag_loop
      end if
      write(err_msg,'("Unrecognised line in input file: ",A)') &
           trim(adjustl(buffer))
      call stop_program(err_msg)
      return
    end select
  end do tag_loop

  if(.not.allocated(output_shape))then
    call stop_program("Reshape layer missing OUTPUT_SHAPE")
    return
  end if

  ! If both INPUT_RANK and INPUT_SHAPE were given they must agree
  if(allocated(input_shape) .and. input_rank .gt. 0)then
    if(input_rank .ne. size(input_shape))then
      write(err_msg,'("INPUT_RANK ",I0," inconsistent with INPUT_SHAPE size ",I0)') &
           input_rank, size(input_shape)
      call stop_program(err_msg)
      return
    end if
  end if


  ! Set hyperparameters and initialise layer
  !---------------------------------------------------------------------------
  call this%set_hyperparams(output_shape = output_shape, verbose = verbose_)
  ! INPUT_SHAPE is optional (it is not written for an uninitialised layer);
  ! passing an unallocated array to init would be invalid
  if(allocated(input_shape)) call this%init(input_shape = input_shape)


  ! Check for end of layer card
  !---------------------------------------------------------------------------
  read(unit,'(A)') buffer
  if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
    write(0,*) trim(adjustl(buffer))
    write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
    call stop_program(err_msg)
    return
  end if

end subroutine read_reshape
353 !###############################################################################
354
355
356 !###############################################################################
function read_reshape_layer(unit, verbose) result(layer)
  !! Read reshape layer from file
  !!
  !! Allocates a placeholder reshape layer (output shape [0]) and fills it
  !! from the layer card at the current position of `unit`.
  implicit none

  ! Arguments
  integer, intent(in) :: unit
  !! File unit
  integer, intent(in), optional :: verbose
  !! Verbosity level
  class(base_layer_type), allocatable :: layer
  !! Instance of the reshape layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assign rather than initialise at declaration (an initialiser implies SAVE)
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  allocate(layer, source=reshape_layer_type(output_shape=[0]))
  call layer%read(unit, verbose=verbose_)

end function read_reshape_layer
378 !###############################################################################
379
380
381 !###############################################################################
subroutine build_from_onnx_reshape(this, node, initialisers, value_info, verbose)
  !! Build reshape layer from ONNX node and initialiser
  !!
  !! The target shape is taken from value_info(1)%dims: all dims after the
  !! first (the first is treated as the batch dimension) or, if only one dim
  !! is present, that single dim. `node` is not used here. Stops the program
  !! if no usable shape is available.
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(inout) :: this
  !! Instance of the reshape layer
  type(onnx_node_type), intent(in) :: node
  !! ONNX node
  type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
  !! ONNX initialisers
  type(onnx_tensor_type), dimension(:), intent(in) :: value_info
  !! ONNX value infos
  integer, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer, dimension(:), allocatable :: output_shape
  !! Output shape

  ! Check size of initialisers is zero or one (shape can be an initialiser)
  if(size(initialisers).gt.1)then
    write(0,*) "WARNING: Multiple initialisers found for ONNX RESHAPE layer"
  end if

  ! Guard each access separately: Fortran does not short-circuit .or., and
  ! value_info(1) / dims(1) must not be touched when they do not exist
  if(size(value_info).lt.1)then
    call stop_program("ONNX RESHAPE layer requires output shape in value_info")
    return
  end if
  if(.not.allocated(value_info(1)%dims))then
    call stop_program("ONNX RESHAPE layer requires output shape in value_info")
    return
  end if

  ! Extract output shape from value_info (excluding batch dimension)
  select case(size(value_info(1)%dims))
  case(0)
    call stop_program("ONNX RESHAPE layer requires output shape in value_info")
    return
  case(1)
    output_shape = value_info(1)%dims(1:1)
  case default
    output_shape = value_info(1)%dims(2:)
  end select

  call this%set_hyperparams( &
       output_shape = output_shape, &
       verbose = verbose &
  )

end subroutine build_from_onnx_reshape
426 !###############################################################################
427
428
429 !###############################################################################
function create_from_onnx_reshape_layer( &
     node, initialisers, value_info, verbose &
) result(layer)
  !! Build reshape layer from ONNX node and initialiser
  !!
  !! Allocates a placeholder reshape layer (output shape [0]) and configures
  !! it through its build_from_onnx binding.
  implicit none

  ! Arguments
  type(onnx_node_type), intent(in) :: node
  !! ONNX node
  type(onnx_initialiser_type), dimension(:), intent(in) :: initialisers
  !! ONNX initialisers
  type(onnx_tensor_type), dimension(:), intent(in) :: value_info
  !! ONNX value infos
  integer, intent(in), optional :: verbose
  !! Verbosity level
  class(base_layer_type), allocatable :: layer
  !! Instance of the reshape layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level

  ! Assign rather than initialise at declaration (an initialiser implies SAVE)
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  allocate(layer, source=reshape_layer_type(output_shape=[0]))
  call layer%build_from_onnx(node, initialisers, value_info, verbose=verbose_)

end function create_from_onnx_reshape_layer
457 !###############################################################################
458
459
460 !##############################################################################!
461 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
462 !##############################################################################!
463
464
465 !###############################################################################
subroutine forward_reshape(this, input)
  !! Forward propagation derived type handler
  !!
  !! Zeroes the output gradient, reshapes input(1,1) to this%output_shape via
  !! diffstruc's reshape, and hands the result to this%output(1,1).
  implicit none

  ! Arguments
  class(reshape_layer_type), intent(inout) :: this
  !! Instance of the reshape layer
  class(array_type), dimension(:,:), intent(in) :: input
  !! Input array

  ! Local variables
  type(array_type), pointer :: ptr
  !! Result returned by reshape, passed on to assign_and_deallocate_source

  ! Nullify here instead of `=> null()` in the declaration: that form implies
  ! SAVE, which makes the procedure non-reentrant
  nullify(ptr)

  ! Reshape input
  !---------------------------------------------------------------------------
  call this%output(1,1)%zero_grad()
  ptr => reshape(input(1,1), this%output_shape)
  call this%output(1,1)%assign_and_deallocate_source(ptr)
  this%output(1,1)%is_temporary = .false.

end subroutine forward_reshape
486 !###############################################################################
487
488 end module athena__reshape_layer
489