| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | !!!############################################################################# | ||
| 2 | !!! Code written by Ned Thaddeus Taylor | ||
| 3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
| 4 | !!!############################################################################# | ||
| 5 | !!! module contains initialiser functions | ||
| 6 | !!! module includes the following procedures: | ||
| 7 | !!! - initialiser_setup - set up initialiser | ||
| 8 | !!! - get_default_initialiser - get default initialiser based on activation ... | ||
| 9 | !!! ... function | ||
| 10 | !!!############################################################################# | ||
| 11 | !! Examples of initialisers in keras: https://keras.io/api/layers/initializers/ | ||
| 12 | !!!############################################################################# | ||
| 13 | module initialiser | ||
| 14 | use misc, only: to_lower | ||
| 15 | use custom_types, only: initialiser_type | ||
| 16 | use initialiser_glorot, only: glorot_uniform, glorot_normal | ||
| 17 | use initialiser_he, only: he_uniform, he_normal | ||
| 18 | use initialiser_lecun, only: lecun_uniform, lecun_normal | ||
| 19 | use initialiser_ones, only: ones | ||
| 20 | use initialiser_zeros, only: zeros | ||
| 21 | use initialiser_ident, only: ident | ||
| 22 | use initialiser_gaussian, only: gaussian | ||
| 23 | implicit none | ||
| 24 | |||
| 25 | |||
| 26 | private | ||
| 27 | |||
| 28 | public :: initialiser_setup, get_default_initialiser | ||
| 29 | |||
| 30 | |||
| 31 | contains | ||
| 32 | |||
| 33 | !!!############################################################################# | ||
| 34 | !!! get default initialiser based on activation function (and if a bias) | ||
| 35 | !!!############################################################################# | ||
| 36 | !!! activation = (S, in) activation function name | ||
| 37 | !!! is_bias = (B, in) if true, then initialiser is for bias | ||
| 38 | !!! name = (S, out) name of default initialiser | ||
function get_default_initialiser(activation, is_bias) result(name)
  !!! Return the name of the default initialiser for a given activation.
  !!!
  !!! activation = (S, in) activation function name (compared as given;
  !!!              no case conversion is applied in this function)
  !!! is_bias    = (B, in, optional) if present and true, the initialiser is
  !!!              for a bias and "zeros" is returned
  !!! name       = (S, out) name of default initialiser; always allocated
  !!!              on return
  implicit none
  character(*), intent(in) :: activation
  logical, optional, intent(in) :: is_bias

  character(:), allocatable :: name


  !!--------------------------------------------------------------------------
  !! if bias, use default initialiser of zero
  !!--------------------------------------------------------------------------
  !! only return early when is_bias is actually true; when it is present but
  !! false, fall through so that the result is never left unallocated
  if(present(is_bias))then
     if(is_bias)then
        name = "zeros"
        return
     end if
  end if


  !!--------------------------------------------------------------------------
  !! set default initialiser based on activation
  !!--------------------------------------------------------------------------
  !! "selu" must be tested before the generic "elu" substring match, as it
  !! also contains "elu"
  if(trim(activation).eq."selu")then
     name = "lecun_normal"
  elseif(index(activation,"elu").ne.0)then
     !! any activation whose name contains "elu" (e.g. "relu", "elu")
     name = "he_uniform"
  elseif(trim(activation).eq."batch")then
     name = "gaussian"
  else
     name = "glorot_uniform"
  end if

end function get_default_initialiser
| 70 | !!!############################################################################# | ||
| 71 | |||
| 72 | |||
| 73 | !!!############################################################################# | ||
| 74 | !!! set up initialiser | ||
| 75 | !!!############################################################################# | ||
| 76 | !!! name = (S, in) name of initialiser | ||
| 77 | !!! error = (I, out) error code | ||
| 78 | !!! initialiser = (O, out) initialiser function | ||
function initialiser_setup(name, error) result(initialiser)
  !!! Select an initialiser object from its name.
  !!!
  !!! name        = (S, in) name of initialiser; lower-cased and trimmed
  !!!               before matching
  !!! error       = (I, out, optional) error code: 0 if the name was
  !!!               recognised, -1 otherwise
  !!! initialiser = (O, out) initialiser object; left unallocated when the
  !!!               name is not recognised and error is present
  !!!
  !!! If the name is not recognised and error is absent, the program stops.
  implicit none
  class(initialiser_type), allocatable :: initialiser
  character(*), intent(in) :: name
  integer, optional, intent(out) :: error


  !!--------------------------------------------------------------------------
  !! report success by default
  !!--------------------------------------------------------------------------
  !! error is intent(out), so it must be defined on every path; without this
  !! the caller would read an undefined value after a successful call
  if(present(error)) error = 0


  !!--------------------------------------------------------------------------
  !! set initialiser function
  !!--------------------------------------------------------------------------
  select case(trim(to_lower(name)))
  case("glorot_uniform")
     initialiser = glorot_uniform
  case("glorot_normal")
     initialiser = glorot_normal
  case("he_uniform")
     initialiser = he_uniform
  case("he_normal")
     initialiser = he_normal
  case("lecun_uniform")
     initialiser = lecun_uniform
  case("lecun_normal")
     initialiser = lecun_normal
  case("ones")
     initialiser = ones
  case("zeros")
     initialiser = zeros
  case("ident")
     initialiser = ident
  case("gaussian", "normal")
     !! "normal" is accepted as an alias of "gaussian"
     initialiser = gaussian
  case default
     if(present(error))then
        error = -1
        return
     else
        stop "Incorrect initialiser name given '"//trim(to_lower(name))//"'"
     end if
  end select

end function initialiser_setup
| 122 | !!!############################################################################# | ||
| 123 | |||
| 124 | end module initialiser | ||
| 125 | !!!############################################################################# | ||
| 126 |