GCC Code Coverage Report


Directory: src/athena/
File: athena_initialiser.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
module athena__initialiser
  !! Module containing functions to set up initialisers
  !!
  !! This module contains functions to set up initialisers for the weights and
  !! biases of a neural network model
  !! Examples of initialisers in keras: https://keras.io/api/layers/initializers/
  use coreutils, only: stop_program, to_lower
  use athena__misc_types, only: base_init_type
  use athena__initialiser_glorot, only: &
       glorot_uniform_init_type, glorot_normal_init_type
  use athena__initialiser_he, only: he_uniform_init_type, he_normal_init_type
  use athena__initialiser_lecun, only: &
       lecun_uniform_init_type, lecun_normal_init_type
  use athena__initialiser_ones, only: ones_init_type
  use athena__initialiser_zeros, only: zeros_init_type
  use athena__initialiser_ident, only: ident_init_type
  use athena__initialiser_gaussian, only: gaussian_init_type
  implicit none


  private

  public :: initialiser_setup, get_default_initialiser


contains

!###############################################################################
  function get_default_initialiser(activation, is_bias) result(name)
    !! Get the default initialiser name based on the activation function
    !!
    !! If is_bias is present and .true., the default is "zeros".
    !! Otherwise the name is chosen from the activation (matched
    !! case-insensitively, trailing blanks ignored):
    !!   - "selu"                     -> "lecun_normal"
    !!   - any name containing "elu"  -> "he_uniform"   (e.g. "relu", "elu")
    !!   - "batch"                    -> "gaussian"
    !!   - anything else              -> "glorot_uniform"
    !! The result is always allocated on return.
    implicit none

    ! Arguments
    character(*), intent(in) :: activation
    !! Activation function name
    logical, optional, intent(in) :: is_bias
    !! Boolean whether initialiser is for bias

    character(:), allocatable :: name
    !! Name of the default initialiser

    ! Local variables
    character(:), allocatable :: act
    !! Lower-cased, right-trimmed activation name


    !---------------------------------------------------------------------------
    ! If bias, use default initialiser of zero
    ! An is_bias of .false. must fall through to the activation-based default
    ! below; returning early here would leave the result unallocated.
    !---------------------------------------------------------------------------
    if(present(is_bias))then
       if(is_bias)then
          name = "zeros"
          return
       end if
    end if


    !---------------------------------------------------------------------------
    ! Set default initialiser based on activation
    !---------------------------------------------------------------------------
    ! Lower-case so the lookup is case-insensitive, matching initialiser_setup
    act = trim(to_lower(activation))
    if(act .eq. "selu")then
       name = "lecun_normal"
    elseif(index(act, "elu") .ne. 0)then
       ! "selu" is handled above; this catches relu/elu-style activations
       name = "he_uniform"
    elseif(act .eq. "batch")then
       name = "gaussian"
    else
       name = "glorot_uniform"
    end if

  end function get_default_initialiser
!###############################################################################


!###############################################################################
  function initialiser_setup(input, error) result(initialiser)
    !! Set up the initialiser function
    !!
    !! input is either an existing initialiser object (copied into the result)
    !! or a case-insensitive initialiser name. Recognised names are:
    !! glorot_uniform, glorot_normal, he_uniform, he_normal, lecun_uniform,
    !! lecun_normal, ones, zeros, ident, gaussian, normal ("normal" builds a
    !! gaussian initialiser with name="normal").
    !!
    !! If error is present it is set to 0 on success and to -1 for an unknown
    !! name or unsupported input type; on failure the result is left
    !! unallocated. If error is absent, failures are reported via stop_program.
    implicit none

    ! Arguments
    class(*), intent(in) :: input
    !! Name of initialiser or initialiser object
    integer, optional, intent(out) :: error
    !! Error code (0 = success, -1 = failure)

    class(base_init_type), allocatable :: initialiser
    !! Initialiser function

    ! Local variables
    character(:), allocatable :: err_msg
    !! Error message (deferred length so long names are not truncated)


    ! intent(out) leaves error undefined on entry; signal success explicitly
    if(present(error)) error = 0

    !---------------------------------------------------------------------------
    ! Set initialiser function
    !---------------------------------------------------------------------------
    select type(input)
    class is(base_init_type)
       initialiser = input
    type is(character(*))
       select case(trim(to_lower(input)))
       case("glorot_uniform")
          initialiser = glorot_uniform_init_type()
       case("glorot_normal")
          initialiser = glorot_normal_init_type()
       case("he_uniform")
          initialiser = he_uniform_init_type()
       case("he_normal")
          initialiser = he_normal_init_type()
       case("lecun_uniform")
          initialiser = lecun_uniform_init_type()
       case("lecun_normal")
          initialiser = lecun_normal_init_type()
       case("ones")
          initialiser = ones_init_type()
       case("zeros")
          initialiser = zeros_init_type()
       case("ident")
          initialiser = ident_init_type()
       case("gaussian")
          initialiser = gaussian_init_type()
       case("normal")
          initialiser = gaussian_init_type(name="normal")
       case default
          if(present(error))then
             error = -1
             return
          end if
          err_msg = "Incorrect initialiser name given '" // &
               trim(to_lower(input)) // "'"
          call stop_program(err_msg)
          ! NOTE(review): return kept in case stop_program does not terminate
          ! (e.g. a non-fatal/test mode) -- confirm against coreutils
          return
       end select
    class default
       if(present(error))then
          error = -1
          return
       end if
       call stop_program("Unknown input type given for initialiser setup")
       return
    end select

  end function initialiser_setup
!###############################################################################

end module athena__initialiser
142