Line | Branch | Exec | Source |
---|---|---|---|
1 | !!!############################################################################# | ||
2 | !!! Code written by Ned Thaddeus Taylor | ||
3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
4 | !!!############################################################################# | ||
5 | !!! module contains initialiser functions | ||
6 | !!! module includes the following procedures: | ||
7 | !!! - initialiser_setup - set up initialiser | ||
8 | !!! - get_default_initialiser - get default initialiser based on activation ... | ||
9 | !!! ... function | ||
10 | !!!############################################################################# | ||
11 | !! Examples of initialisers in keras: https://keras.io/api/layers/initializers/ | ||
12 | !!!############################################################################# | ||
module initialiser
  !! Initialiser factory module.
  !!  - initialiser_setup       : build an initialiser object from its name
  !!  - get_default_initialiser : pick a default initialiser name for a given
  !!                              activation function (or for a bias)
  use misc, only: to_lower
  use custom_types, only: initialiser_type
  use initialiser_glorot, only: glorot_uniform, glorot_normal
  use initialiser_he, only: he_uniform, he_normal
  use initialiser_lecun, only: lecun_uniform, lecun_normal
  use initialiser_ones, only: ones
  use initialiser_zeros, only: zeros
  use initialiser_ident, only: ident
  use initialiser_gaussian, only: gaussian
  implicit none


  private

  public :: initialiser_setup, get_default_initialiser


contains

!!!#############################################################################
!!! get default initialiser based on activation function (and if a bias)
!!!#############################################################################
!!! activation = (S, in)  activation function name (compared case-sensitively)
!!! is_bias    = (B, in)  optional; if .true., the initialiser is for a bias
!!! name       = (S, out) name of the default initialiser; always allocated
!!!                       on return
  function get_default_initialiser(activation, is_bias) result(name)
    implicit none
    character(*), intent(in) :: activation
    logical, optional, intent(in) :: is_bias

    character(:), allocatable :: name


    !!--------------------------------------------------------------------------
    !! if bias, use default initialiser of zero
    !!--------------------------------------------------------------------------
    !! Only return early when is_bias is actually .true.; a present-but-.false.
    !! flag must fall through, otherwise name would be returned unallocated.
    !! (Nested ifs: Fortran does not guarantee short-circuit evaluation, so
    !! present(is_bias) .and. is_bias is not safe.)
    if(present(is_bias))then
       if(is_bias)then
          name = "zeros"
          return
       end if
    end if


    !!--------------------------------------------------------------------------
    !! set default initialiser based on activation
    !!--------------------------------------------------------------------------
    if(trim(activation).eq."selu")then
       name = "lecun_normal"
    elseif(index(activation,"elu").ne.0)then
       !! any other name containing "elu" (e.g. "elu", "relu", "leaky_relu");
       !! "selu" is caught by the branch above first
       !! NOTE(review): match is case-sensitive, unlike initialiser_setup
       name = "he_uniform"
    elseif(trim(activation).eq."batch")then
       name = "gaussian"
    else
       name = "glorot_uniform"
    end if

  end function get_default_initialiser
!!!#############################################################################


!!!#############################################################################
!!! set up initialiser
!!!#############################################################################
!!! name        = (S, in)  name of initialiser (lower-cased before matching)
!!! error       = (I, out) optional error code: 0 on success, -1 if the name
!!!                        is not recognised
!!! initialiser = (O, out) initialiser object; left unallocated when an
!!!                        unrecognised name is reported through error
!!! If name is not recognised and error is absent, execution halts with
!!! ERROR STOP (non-zero exit status).
  function initialiser_setup(name, error) result(initialiser)
    implicit none
    class(initialiser_type), allocatable :: initialiser
    character(*), intent(in) :: name
    integer, optional, intent(out) :: error


    !! error is intent(out), so it is undefined on entry; report success
    !! explicitly so callers can always test it after the call
    if(present(error)) error = 0


    !!--------------------------------------------------------------------------
    !! set initialiser function
    !!--------------------------------------------------------------------------
    select case(trim(to_lower(name)))
    case("glorot_uniform")
       initialiser = glorot_uniform
    case("glorot_normal")
       initialiser = glorot_normal
    case("he_uniform")
       initialiser = he_uniform
    case("he_normal")
       initialiser = he_normal
    case("lecun_uniform")
       initialiser = lecun_uniform
    case("lecun_normal")
       initialiser = lecun_normal
    case("ones")
       initialiser = ones
    case("zeros")
       initialiser = zeros
    case("ident")
       initialiser = ident
    case("gaussian", "normal")
       !! "normal" is accepted as an alias for the gaussian initialiser
       initialiser = gaussian
    case default
       if(present(error))then
          error = -1
          return
       else
          !! unrecoverable: terminate with a non-zero exit status
          error stop "Incorrect initialiser name given '"//trim(to_lower(name))//"'"
       end if
    end select

  end function initialiser_setup
!!!#############################################################################

end module initialiser
125 | !!!############################################################################# | ||
126 |