GCC Code Coverage Report


Directory: src/athena/
File: athena_recurrent_layer.f90
Date: 2025-12-10 07:37:07
            Exec  Total  Coverage
Lines:         0      0    100.0%
Functions:     0      0        -%
Branches:      0      0        -%

Line Branch Exec Source
1 module athena__recurrent_layer
2 !! Module containing implementation of recurrent neural network layers
3 !!
4 !! This module implements the simple recurrent neural network (RNN) layer,
5 !! which is designed to handle sequential data by maintaining a hidden state.
6 !!
7 !! **Simple RNN layer (equivalent to RNNCell of PyTorch):**
8 !! \[
9 !! \begin{align}
10 !! \mathbf{h}_t &= \sigma(\mathbf{W}_{ih}\mathbf{x}_t + \mathbf{b}_{ih} + \mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{b}_{hh}) \\
11 !! \mathbf{y}_t &= \mathbf{h}_t
12 !! \end{align}
13 !! \]
14 !!
15 !! where:
16 !! - \(\mathbf{x}_t\) is input at time t
17 !! - \(\mathbf{h}_t\) is hidden state at time t
18 !! - \(\sigma\) is the activation function (e.g., tanh, relu)
19 !! - \(\mathbf{W}\) matrices are learnable weights
20 !! - \(\mathbf{b}\) vectors are learnable biases
21 !!
22 !! Properties:
23 !! - Processes sequential data with temporal dependencies
24 !! - Maintains hidden state across time steps
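!!
!! Example (a minimal usage sketch; `layer`, `x_t`, and `num_steps` are
!! illustrative names, not part of this module):
!!
!!     type(recurrent_layer_type) :: layer
!!     layer = recurrent_layer_type(hidden_size=8, input_size=4)
!!     do t = 1, num_steps
!!        call layer%forward(x_t)   ! one step; hidden state carries over
!!     end do
!!     call layer%reset_state()     ! clear state between sequences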
25 use coreutils, only: real32, stop_program
26 use athena__base_layer, only: learnable_layer_type, base_layer_type
27 use athena__misc_types, only: base_actv_type, base_init_type, &
28 onnx_node_type, onnx_initialiser_type
29 use diffstruc, only: array_type, matmul, operator(+), operator(*)
30 implicit none
31
32
33 private
34
35 public :: recurrent_layer_type
36 public :: read_recurrent_layer
37
38
39 type, extends(learnable_layer_type) :: recurrent_layer_type
40 !! Type for simple RNN layer
41 integer :: hidden_size
42 !! Size of hidden state
43 integer :: input_size
44 !! Size of input
45 integer :: time_step
46 !! Current time step
47 type(array_type), pointer :: hidden_state => null()
48 !! Hidden state
49 contains
50 procedure, pass(this) :: get_num_params => get_num_params_recurrent
51 procedure, pass(this) :: set_hyperparams => set_hyperparams_recurrent
52 procedure, pass(this) :: init => init_recurrent
53 procedure, pass(this) :: print_to_unit => print_to_unit_recurrent
54 procedure, pass(this) :: read => read_recurrent
55 procedure, pass(this) :: forward => forward_recurrent
56 procedure, pass(this) :: reset_state => reset_state_recurrent
57 end type recurrent_layer_type
58
59 interface recurrent_layer_type
60 module function layer_setup( &
61 hidden_size, input_size, use_bias, &
62 activation, &
63 kernel_initialiser, bias_initialiser, verbose &
64 ) result(layer)
65 integer, intent(in) :: hidden_size
66 integer, optional, intent(in) :: input_size
67 logical, optional, intent(in) :: use_bias
68 class(*), optional, intent(in) :: activation
69 class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
70 integer, optional, intent(in) :: verbose
71 type(recurrent_layer_type) :: layer
72 end function layer_setup
73 end interface recurrent_layer_type
74
75
76
77 contains
78
79 !###############################################################################
80 pure function get_num_params_recurrent(this) result(num_params)
81 implicit none
82 class(recurrent_layer_type), intent(in) :: this
83 integer :: num_params
84
85 num_params = &
86 this%hidden_size * this%input_size + & ! W_ih
87 this%hidden_size * this%hidden_size ! W_hh
88 if(this%use_bias) then
89 num_params = num_params + 2 * this%hidden_size ! b_ih + b_hh
90 end if
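! Illustrative check: hidden_size=8, input_size=4 with bias gives
! 8*4 + 8*8 + 2*8 = 112 parameters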
91
92 end function get_num_params_recurrent
93 !###############################################################################
94
95
96 !###############################################################################
97 subroutine reset_state_recurrent(this)
98 !! Reset the hidden state of the recurrent layer
99 implicit none
100
101 ! Arguments
102 class(recurrent_layer_type), intent(inout) :: this
103 !! Instance of the recurrent layer
104
105 this%time_step = 0
106 if(associated(this%hidden_state))then
107 call this%hidden_state%deallocate()
108 nullify(this%hidden_state)
109 end if
110
111 end subroutine reset_state_recurrent
112 !###############################################################################
113
114
115 !##############################################################################!
116 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
117 !##############################################################################!
118
119
120 !###############################################################################
121 module function layer_setup( &
122 hidden_size, input_size, use_bias, &
123 activation, &
124 kernel_initialiser, bias_initialiser, verbose &
125 ) result(layer)
126 !! Set up a recurrent layer
127 use athena__activation, only: activation_setup
128 use athena__initialiser, only: initialiser_setup
129 implicit none
130
131 ! Arguments
132 integer, intent(in) :: hidden_size
133 !! Size of hidden state
134 integer, optional, intent(in) :: input_size
135 !! Size of input
136 logical, optional, intent(in) :: use_bias
137 !! Whether to use bias
138 class(*), optional, intent(in) :: activation
139 !! Activation function
140 class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
141 !! Kernel and bias initialisers
142 integer, optional, intent(in) :: verbose
143 !! Verbosity level
144
145 type(recurrent_layer_type) :: layer
146 !! Instance of the recurrent layer
147
148 ! Local variables
149 integer :: verbose_
150 !! Verbosity level
151 logical :: use_bias_
152 !! Whether to use bias
153 class(base_actv_type), allocatable :: activation_
154 !! Activation function
155 class(base_init_type), allocatable :: kernel_initialiser_, bias_initialiser_
156 !! Kernel and bias initialisers
157
! Assign defaults at run time (declaration-time init would imply SAVE)
158 verbose_ = 0; if(present(verbose)) verbose_ = verbose
159
160
161 !---------------------------------------------------------------------------
162 ! Set use_bias
163 !---------------------------------------------------------------------------
164 use_bias_ = .true.; if(present(use_bias)) use_bias_ = use_bias
165
166
167 !---------------------------------------------------------------------------
168 ! Set activation functions based on input name
169 !---------------------------------------------------------------------------
170 if(present(activation))then
171 activation_ = activation_setup(activation)
172 else
173 activation_ = activation_setup("tanh")
174 end if
175
176
177 !---------------------------------------------------------------------------
178 ! Define weights (kernels) and biases initialisers
179 !---------------------------------------------------------------------------
180 if(present(kernel_initialiser))then
181 kernel_initialiser_ = initialiser_setup(kernel_initialiser)
182 end if
183 if(present(bias_initialiser))then
184 bias_initialiser_ = initialiser_setup(bias_initialiser)
185 end if
186
187
188 !---------------------------------------------------------------------------
189 ! Set hyperparameters
190 !---------------------------------------------------------------------------
191 call layer%set_hyperparams( &
192 hidden_size = hidden_size, &
193 use_bias = use_bias_, &
194 activation = activation_, &
195 kernel_initialiser = kernel_initialiser_, &
196 bias_initialiser = bias_initialiser_, &
197 verbose = verbose_ &
198 )
199
200
201 !---------------------------------------------------------------------------
202 ! Initialise layer shape
203 !---------------------------------------------------------------------------
204 if(present(input_size)) call layer%init(input_shape=[input_size])
205
206 end function layer_setup
207 !###############################################################################
208
209
210 !###############################################################################
211 subroutine set_hyperparams_recurrent( &
212 this, hidden_size, &
213 use_bias, &
214 activation, &
215 kernel_initialiser, bias_initialiser, &
216 verbose &
217 )
218 !! Set the hyperparameters for the recurrent layer
219 use athena__activation, only: activation_setup
220 use athena__initialiser, only: get_default_initialiser, initialiser_setup
221 implicit none
222
223 ! Arguments
224 class(recurrent_layer_type), intent(inout) :: this
225 !! Instance of the recurrent layer
226 integer, intent(in) :: hidden_size
227 !! Number of hidden units
228 logical, intent(in) :: use_bias
229 !! Whether to use bias
230 class(base_actv_type), allocatable, intent(in) :: activation
231 !! Activation function
232 class(base_init_type), allocatable, intent(in) :: &
233 kernel_initialiser, bias_initialiser
234 !! Kernel and bias initialisers
235 integer, optional, intent(in) :: verbose
236 !! Verbosity level
237
238 ! Local variables
239 character(len=256) :: buffer
240
241
242 this%name = "recu"
243 this%type = "recurrent"
244 this%input_rank = 1
245 this%output_rank = 1
246 this%use_bias = use_bias
247 this%hidden_size = hidden_size
248 if(allocated(this%activation)) deallocate(this%activation)
249 if(.not.allocated(activation))then
250 this%activation = activation_setup("none")
251 else
252 allocate(this%activation, source=activation)
253 end if
254 if(allocated(this%kernel_init)) deallocate(this%kernel_init)
255 if(.not.allocated(kernel_initialiser))then
256 buffer = get_default_initialiser(this%activation%name)
257 this%kernel_init = initialiser_setup(buffer)
258 else
259 allocate(this%kernel_init, source=kernel_initialiser)
260 end if
261 if(allocated(this%bias_init)) deallocate(this%bias_init)
262 if(.not.allocated(bias_initialiser))then
263 buffer = get_default_initialiser( &
264 this%activation%name, &
265 is_bias=.true. &
266 )
267 this%bias_init = initialiser_setup(buffer)
268 else
269 allocate(this%bias_init, source=bias_initialiser)
270 end if
271 if(present(verbose))then
272 if(abs(verbose).gt.0)then
273 write(*,'("RECU activation function: ",A)') &
274 trim(this%activation%name)
275 write(*,'("RECU kernel initialiser: ",A)') &
276 trim(this%kernel_init%name)
277 write(*,'("RECU bias initialiser: ",A)') &
278 trim(this%bias_init%name)
279 end if
280 end if
281
282 end subroutine set_hyperparams_recurrent
283 !###############################################################################
284
285
286 !###############################################################################
287 subroutine init_recurrent(this, input_shape, verbose)
288 !! Initialise the recurrent layer
289 implicit none
290
291 ! Arguments
292 class(recurrent_layer_type), intent(inout) :: this
293 !! Instance of the recurrent layer
294 integer, dimension(:), intent(in) :: input_shape
295 !! Shape of the input
296 integer, optional, intent(in) :: verbose
297 !! Verbosity level
298
299 ! Local variables
300 integer :: num_inputs
301 !! Effective fan-in for weight initialisation
302 integer :: verbose_
!! Verbosity level
303
304
305 !---------------------------------------------------------------------------
306 ! Initialise optional arguments
307 !---------------------------------------------------------------------------
308 verbose_ = 0; if(present(verbose)) verbose_ = verbose
309
310
311 !---------------------------------------------------------------------------
312 ! Initialise number of inputs
313 !---------------------------------------------------------------------------
314 if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
315 this%input_size = this%input_shape(1)
316 this%output_shape = [this%hidden_size]
317 this%num_params = this%get_num_params()
318
319
320 !---------------------------------------------------------------------------
321 ! Allocate weight, weight steps (velocities), output, and activation
322 !---------------------------------------------------------------------------
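! Parameter layout (matches the forward pass): params(1) = W_ih,
! params(2) = W_hh, params(3) = input-branch bias, params(4) = hidden-branch bias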
323 allocate(this%weight_shape(2,2))
324 this%weight_shape(:,1) = [ this%hidden_size, this%input_size ]
325 this%weight_shape(:,2) = [ this%hidden_size, this%hidden_size ]
326
327 if(this%use_bias)then
328 this%bias_shape = [ this%hidden_size, this%hidden_size ]
329 allocate(this%params(4))
330 else
331 allocate(this%params(2))
332 end if
333 call this%params(1)%allocate([this%weight_shape(:,1), 1])
334 call this%params(1)%set_requires_grad(.true.)
335 this%params(1)%fix_pointer = .true.
336 this%params(1)%is_sample_dependent = .false.
337 this%params(1)%is_temporary = .false.
338 call this%params(2)%allocate([this%weight_shape(:,2), 1])
339 call this%params(2)%set_requires_grad(.true.)
340 this%params(2)%fix_pointer = .true.
341 this%params(2)%is_sample_dependent = .false.
342 this%params(2)%is_temporary = .false.
343
344 num_inputs = this%input_size + this%hidden_size
345 if(this%use_bias)then
346 num_inputs = num_inputs + 2 * this%hidden_size
347 call this%params(3)%allocate([this%bias_shape(1), 1])
348 call this%params(3)%set_requires_grad(.true.)
349 this%params(3)%fix_pointer = .true.
350 this%params(3)%is_sample_dependent = .false.
351 this%params(3)%is_temporary = .false.
352 call this%params(4)%allocate([this%bias_shape(2), 1])
353 call this%params(4)%set_requires_grad(.true.)
354 this%params(4)%fix_pointer = .true.
355 this%params(4)%is_sample_dependent = .false.
356 this%params(4)%is_temporary = .false.
357 end if
358
359
360 !---------------------------------------------------------------------------
361 ! Initialise weights (kernels)
362 !---------------------------------------------------------------------------
363 call this%kernel_init%initialise( &
364 this%params(1)%val(:,1), &
365 fan_in = num_inputs, fan_out = this%hidden_size, &
366 spacing = [ this%hidden_size ] &
367 )
368 call this%kernel_init%initialise( &
369 this%params(2)%val(:,1), &
370 fan_in = num_inputs, fan_out = this%hidden_size, &
371 spacing = [ this%hidden_size ] &
372 )
373
374 ! Initialise biases
375 !---------------------------------------------------------------------------
376 if(this%use_bias)then
377 call this%bias_init%initialise( &
378 this%params(3)%val(:,1), &
379 fan_in = num_inputs, fan_out = this%hidden_size &
380 )
381 call this%bias_init%initialise( &
382 this%params(4)%val(:,1), &
383 fan_in = num_inputs, fan_out = this%hidden_size &
384 )
385 end if
386
387
388 !---------------------------------------------------------------------------
389 ! Allocate arrays and initialise time_step
390 !---------------------------------------------------------------------------
391 if(allocated(this%output)) deallocate(this%output)
392 allocate(this%output(1,1))
393 this%time_step = 0
394
395 end subroutine init_recurrent
396 !###############################################################################
397
398
399 !##############################################################################!
400 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
401 !##############################################################################!
402
403
404 !###############################################################################
405 subroutine print_to_unit_recurrent(this, unit)
406 !! Print recurrent layer to unit
407 use coreutils, only: to_upper
408 implicit none
409
410 ! Arguments
411 class(recurrent_layer_type), intent(in) :: this
412 !! Instance of the recurrent layer
413 integer, intent(in) :: unit
414 !! File unit
415
416
417 ! Write initial parameters
418 !---------------------------------------------------------------------------
419 write(unit,'(3X,"INPUT_SIZE = ",I0)') this%input_size
420 write(unit,'(3X,"HIDDEN_SIZE = ",I0)') this%hidden_size
421
422 write(unit,'(3X,"USE_BIAS = ",L1)') this%use_bias
423 if(this%activation%name .ne. 'none')then
424 call this%activation%print_to_unit(unit)
425 end if
426
427
428 ! Write recurrent weights and biases
429 !---------------------------------------------------------------------------
430 write(unit,'("WEIGHTS")')
431 write(unit,'(5(E16.8E2))') this%params(1)%val(:,1)
432 write(unit,'(5(E16.8E2))') this%params(2)%val(:,1)
433 if(this%use_bias)then
434 write(unit,'(5(E16.8E2))') this%params(3)%val(:,1)
435 write(unit,'(5(E16.8E2))') this%params(4)%val(:,1)
436 end if
437 write(unit,'("END WEIGHTS")')
438
439 end subroutine print_to_unit_recurrent
440 !###############################################################################
441
442
443 !###############################################################################
444 subroutine read_recurrent(this, unit, verbose)
445 !! Read recurrent layer from file
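!! The layer card mirrors the output of print_to_unit_recurrent, e.g.
!! (values illustrative):
!!     INPUT_SIZE = 4
!!     HIDDEN_SIZE = 8
!!     USE_BIAS = T
!!     WEIGHTS
!!     ... flattened W_ih, then W_hh, then the two bias vectors ...
!!     END WEIGHTS
!!     END RECU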
446 use athena__tools_infile, only: assign_val, assign_vec, move
447 use coreutils, only: to_lower, to_upper, icount
448 use athena__activation, only: read_activation
449 use athena__initialiser, only: initialiser_setup
450 implicit none
451
452 ! Arguments
453 class(recurrent_layer_type), intent(inout) :: this
454 !! Instance of the recurrent layer
455 integer, intent(in) :: unit
456 !! Unit number
457 integer, optional, intent(in) :: verbose
458 !! Verbosity level
459
460 ! Local variables
461 integer :: stat
462 !! Status of read
463 integer :: verbose_
464 !! Verbosity level
465 integer :: i, j, k, c, itmp1, iline, num_params
466 !! Loop variables and temporary integer
467 integer :: input_size, hidden_size
468 !! Input and hidden sizes
469 logical :: use_bias
470 !! Whether to use bias
471 character(14) :: kernel_initialiser_name, bias_initialiser_name
472 !! Initialisers
473 character(20) :: activation_name
474 !! Activation function
475 class(base_actv_type), allocatable :: activation
476 !! Activation function
477 class(base_init_type), allocatable :: kernel_initialiser, bias_initialiser
478 !! Initialisers
479 character(256) :: buffer, tag, err_msg
480 !! Buffer, tag, and error message
481 integer, dimension(2) :: input_shape
482 !! Input shape
483 real(real32), allocatable, dimension(:) :: data_list
484 !! Data list
485 integer :: param_line, final_line
486 !! Line numbers of the WEIGHTS card and the layer END card
487
488
489 ! Initialise optional arguments and defaults (set at run time, since
! declaration-time initialisation would imply SAVE)
490 !---------------------------------------------------------------------------
491 verbose_ = 0; if(present(verbose)) verbose_ = verbose
use_bias = .true.
input_size = 0; hidden_size = 0
kernel_initialiser_name = ''; bias_initialiser_name = ''
492
493
494 ! Loop over tags in layer card
495 !---------------------------------------------------------------------------
496 iline = 0
497 param_line = 0
498 final_line = 0
499 tag_loop: do
500
501 ! Check for end of file
502 !------------------------------------------------------------------------
503 read(unit,'(A)',iostat=stat) buffer
504 if(stat.ne.0)then
505 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
506 to_upper(this%name)
507 call stop_program(err_msg)
508 return
509 end if
510 if(trim(adjustl(buffer)).eq."") cycle tag_loop
511
512 ! Check for end of layer card
513 !------------------------------------------------------------------------
514 if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
515 final_line = iline
516 backspace(unit)
517 exit tag_loop
518 end if
519 iline = iline + 1
520
521 tag=trim(adjustl(buffer))
522 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
523
524 ! Read parameters from file
525 !------------------------------------------------------------------------
526 select case(trim(tag))
527 case("INPUT_SIZE", "NUM_INPUTS")
528 call assign_val(buffer, input_size, itmp1)
529 case("HIDDEN_SIZE", "NUM_OUTPUTS")
530 call assign_val(buffer, hidden_size, itmp1)
531 case("USE_BIAS")
532 call assign_val(buffer, use_bias, itmp1)
533 case("ACTIVATION")
534 iline = iline - 1
535 backspace(unit)
536 activation = read_activation(unit, iline)
537 case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER")
538 call assign_val(buffer, kernel_initialiser_name, itmp1)
539 case("BIAS_INITIALISER", "BIAS_INIT", "BIAS_INITIALIZER")
540 call assign_val(buffer, bias_initialiser_name, itmp1)
541 case("WEIGHTS")
542 kernel_initialiser_name = 'zeros'
543 bias_initialiser_name = 'zeros'
544 param_line = iline
545 case default
546 ! Skip "e" when scanning for letters, since it appears in the
547 ! scientific-notation exponent of numbers (e.g. E+00)
548 if(scan(to_lower(trim(adjustl(buffer))),&
549 'abcdfghijklmnopqrstuvwxyz').eq.0)then
550 cycle tag_loop
551 elseif(tag(:3).eq.'END')then
552 cycle tag_loop
553 end if
554 write(err_msg,'("Unrecognised line in input file: ",A)') &
555 trim(adjustl(buffer))
556 call stop_program(err_msg)
557 return
558 end select
559 end do tag_loop
560 kernel_initialiser = initialiser_setup(kernel_initialiser_name)
561 bias_initialiser = initialiser_setup(bias_initialiser_name)
562
563
564 ! Set hyperparameters and initialise layer
565 !---------------------------------------------------------------------------
566 call this%set_hyperparams( &
567 hidden_size = hidden_size, &
568 use_bias = use_bias, &
569 activation = activation, &
570 kernel_initialiser = kernel_initialiser, &
571 bias_initialiser = bias_initialiser, &
572 verbose = verbose_ &
573 )
574 call this%init(input_shape=[input_size])
575
576
577 ! Check if WEIGHTS card was found
578 !---------------------------------------------------------------------------
579 if(param_line.eq.0)then
580 write(0,*) "WARNING: WEIGHTS card in "//to_upper(trim(this%name))//" not found"
581 else
582 call move(unit, param_line - iline, iostat=stat)
583 num_params = this%input_size * this%hidden_size
584 allocate(data_list(num_params), source=0._real32)
585 c = 1
586 k = 1
587 data_concat_loop: do while(c.le.num_params)
588 read(unit,'(A)',iostat=stat) buffer
589 if(stat.ne.0) exit data_concat_loop
590 k = icount(buffer)
591 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
592 c = c + k
593 end do data_concat_loop
594 this%params(1)%val(:,1) = data_list
595 deallocate(data_list)
596 num_params = this%hidden_size * this%hidden_size
597 allocate(data_list(num_params), source=0._real32)
598 c = 1
599 k = 1
600 data_concat_loop1: do while(c.le.num_params)
601 read(unit,'(A)',iostat=stat) buffer
602 if(stat.ne.0) exit data_concat_loop1
603 k = icount(buffer)
604 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
605 c = c + k
606 end do data_concat_loop1
607 this%params(2)%val(:,1) = data_list
608 deallocate(data_list)
609 if(use_bias)then
610 do i = 1, 2
611 hidden_size = this%hidden_size
612 allocate(data_list(hidden_size), source=0._real32)
613 c = 1
614 k = 1
615 data_concat_loop_bias: do while(c.le.hidden_size)
616 read(unit,'(A)',iostat=stat) buffer
617 if(stat.ne.0) exit data_concat_loop_bias
618 k = icount(buffer)
619 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
620 c = c + k
621 end do data_concat_loop_bias
622 this%params(i+2)%val(:,1) = data_list(1:hidden_size)
623 deallocate(data_list)
624 end do
625 end if
626
627 ! Check for end of weights card
628 !------------------------------------------------------------------------
629 read(unit,'(A)') buffer
630 if(trim(adjustl(buffer)).ne."END WEIGHTS")then
631 write(0,*) trim(adjustl(buffer))
632 call stop_program("END WEIGHTS not where expected")
633 return
634 end if
635 end if
636
637
638 !---------------------------------------------------------------------------
639 ! Check for end of layer card
640 !---------------------------------------------------------------------------
641 call move(unit, final_line - iline, iostat=stat)
642 read(unit,'(A)') buffer
643 if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
644 write(0,*) trim(adjustl(buffer))
645 write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
646 call stop_program(err_msg)
647 return
648 end if
649
650 end subroutine read_recurrent
651 !###############################################################################
652
653
654 !###############################################################################
655 function read_recurrent_layer(unit, verbose) result(layer)
656 !! Read recurrent layer from file and return layer
657 implicit none
658
659 ! Arguments
660 integer, intent(in) :: unit
661 !! Unit number
662 integer, optional, intent(in) :: verbose
663 !! Verbosity level
664 class(base_layer_type), allocatable :: layer
665 !! Instance of the recurrent layer
666
667 ! Local variables
668 integer :: verbose_
669 !! Verbosity level
670
671 verbose_ = 0; if(present(verbose)) verbose_ = verbose
672 allocate(layer, source=recurrent_layer_type(hidden_size=0))
673 call layer%read(unit, verbose=verbose_)
674
675 end function read_recurrent_layer
676 !###############################################################################
677
678
679 !##############################################################################!
680 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
681 !##############################################################################!
682
683
684 !###############################################################################
685 subroutine forward_recurrent(this, input)
686 !! Forward propagation
687 implicit none
688
689 ! Arguments
690 class(recurrent_layer_type), intent(inout) :: this
691 !! Instance of the recurrent layer
692 class(array_type), dimension(:,:), intent(in) :: input
693 !! Input values
694
695 type(array_type), pointer :: ptr1, ptr2, ptr
!! Intermediate results of the forward pass
696
697 if(.not.associated(this%hidden_state))then
698 call this%reset_state()
699 allocate(this%hidden_state)
700 call this%hidden_state%allocate( &
701 [this%hidden_size, size(input(1,1)%val,2)], &
702 source = 0._real32 &
703 )
704 this%hidden_state%is_temporary = .false.
705 end if
706
707
708 ! Generate outputs from weights, biases, and inputs
709 !---------------------------------------------------------------------------
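! In the header's notation: ptr1 = W_ih x_t (+ b_ih) and
! ptr2 = W_hh h_{t-1} (+ b_hh); the new hidden state is sigma(ptr1 + ptr2)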
710 if(this%use_bias)then
711 ptr1 => matmul(this%params(1), input(1,1) ) + this%params(3)
712 ptr2 => matmul(this%params(2), this%hidden_state ) + this%params(4)
713 else
714 ptr1 => matmul(this%params(1), input(1,1) )
715 ptr2 => matmul(this%params(2), this%hidden_state )
716 end if
717 ptr => ptr1 + ptr2
718
719 ! Apply activation function to the pre-activation
720 !---------------------------------------------------------------------------
721 call this%output(1,1)%zero_grad()
722 if(trim(this%activation%name) .ne. "none") then
723 ptr => this%activation%apply(ptr)
724 end if
725 this%hidden_state => ptr
726 call this%output(1,1)%assign_shallow(ptr)
727 this%output(1,1)%is_temporary = .false.
728 this%time_step = this%time_step + 1
729
730 end subroutine forward_recurrent
731 !###############################################################################
732
733 end module athena__recurrent_layer
734