| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__initialiser | ||
| 2 | !! Module containing functions to set up initialisers | ||
| 3 | !! | ||
| 4 | !! This module contains functions to set up initialisers for the weights and | ||
| 5 | !! biases of a neural network model | ||
| 6 | !! Examples of initialisers in Keras: https://keras.io/api/layers/initializers/ | ||
| 7 | use coreutils, only: stop_program, to_lower | ||
| 8 | use athena__misc_types, only: base_init_type | ||
| 9 | use athena__initialiser_glorot, only: & | ||
| 10 | glorot_uniform_init_type, glorot_normal_init_type | ||
| 11 | use athena__initialiser_he, only: he_uniform_init_type, he_normal_init_type | ||
| 12 | use athena__initialiser_lecun, only: & | ||
| 13 | lecun_uniform_init_type, lecun_normal_init_type | ||
| 14 | use athena__initialiser_ones, only: ones_init_type | ||
| 15 | use athena__initialiser_zeros, only: zeros_init_type | ||
| 16 | use athena__initialiser_ident, only: ident_init_type | ||
| 17 | use athena__initialiser_gaussian, only: gaussian_init_type | ||
| 18 | implicit none | ||
| 19 | |||
| 20 | |||
| 21 | private | ||
| 22 | |||
| 23 | public :: initialiser_setup, get_default_initialiser | ||
| 24 | |||
| 25 | |||
| 26 | contains | ||
| 27 | |||
| 28 | !############################################################################### | ||
| 29 | − | function get_default_initialiser(activation, is_bias) result(name) | |
| 30 | !! Get the default initialiser based on the activation function | ||
| 31 | implicit none | ||
| 32 | |||
| 33 | ! Arguments | ||
| 34 | character(*), intent(in) :: activation | ||
| 35 | !! Activation function | ||
| 36 | logical, optional, intent(in) :: is_bias | ||
| 37 | !! Boolean whether initialiser is for bias | ||
| 38 | |||
| 39 | character(:), allocatable :: name | ||
| 40 | |||
| 41 | |||
| 42 | !--------------------------------------------------------------------------- | ||
| 43 | ! If bias, use default initialiser of zero | ||
| 44 | !--------------------------------------------------------------------------- | ||
| 45 | − | if(present(is_bias))then | |
| 46 | − | if(is_bias) name = "zeros" | |
| 47 | − | return | |
| 48 | end if | ||
| 49 | |||
| 50 | |||
| 51 | !--------------------------------------------------------------------------- | ||
| 52 | ! Set default initialiser based on activation | ||
| 53 | !--------------------------------------------------------------------------- | ||
| 54 | − | if(trim(activation).eq."selu")then | |
| 55 | − | name = "lecun_normal" | |
| 56 | − | elseif(index(activation,"elu").ne.0)then | |
| 57 | − | name = "he_uniform" | |
| 58 | − | elseif(trim(activation).eq."batch")then | |
| 59 | − | name = "gaussian" | |
| 60 | else | ||
| 61 | − | name = "glorot_uniform" | |
| 62 | end if | ||
| 63 | |||
| 64 | − | end function get_default_initialiser | |
| 65 | !############################################################################### | ||
| 66 | |||
| 67 | |||
| 68 | !############################################################################### | ||
| 69 | − | function initialiser_setup(input, error) result(initialiser) | |
| 70 | !! Set up the initialiser function | ||
| 71 | implicit none | ||
| 72 | |||
| 73 | ! Arguments | ||
| 74 | class(base_init_type), allocatable :: initialiser | ||
| 75 | !! Initialiser function | ||
| 76 | class(*) :: input | ||
| 77 | !! Name of initialiser or initialiser object | ||
| 78 | integer, optional, intent(out) :: error | ||
| 79 | !! Error code | ||
| 80 | |||
| 81 | ! Local variables | ||
| 82 | character(256) :: err_msg | ||
| 83 | !! Error message | ||
| 84 | |||
| 85 | |||
| 86 | !--------------------------------------------------------------------------- | ||
| 87 | ! Set initialiser function | ||
| 88 | !--------------------------------------------------------------------------- | ||
| 89 | select type(input) | ||
| 90 | class is(base_init_type) | ||
| 91 | − | initialiser = input | |
| 92 | type is(character(*)) | ||
| 93 | − | select case(trim(to_lower(input))) | |
| 94 | case("glorot_uniform") | ||
| 95 | − | initialiser = glorot_uniform_init_type() | |
| 96 | case("glorot_normal") | ||
| 97 | − | initialiser = glorot_normal_init_type() | |
| 98 | case("he_uniform") | ||
| 99 | − | initialiser = he_uniform_init_type() | |
| 100 | case("he_normal") | ||
| 101 | − | initialiser = he_normal_init_type() | |
| 102 | case("lecun_uniform") | ||
| 103 | − | initialiser = lecun_uniform_init_type() | |
| 104 | case("lecun_normal") | ||
| 105 | − | initialiser = lecun_normal_init_type() | |
| 106 | case("ones") | ||
| 107 | − | initialiser = ones_init_type() | |
| 108 | case("zeros") | ||
| 109 | − | initialiser = zeros_init_type() | |
| 110 | case("ident") | ||
| 111 | − | initialiser = ident_init_type() | |
| 112 | case("gaussian") | ||
| 113 | − | initialiser = gaussian_init_type() | |
| 114 | case("normal") | ||
| 115 | − | initialiser = gaussian_init_type(name="normal") | |
| 116 | case default | ||
| 117 | − | if(present(error))then | |
| 118 | − | error = -1 | |
| 119 | − | return | |
| 120 | else | ||
| 121 | write(err_msg,'("Incorrect initialiser name given ''",A,"''")') & | ||
| 122 | − | trim(to_lower(input)) | |
| 123 | − | call stop_program(trim(err_msg)) | |
| 124 | − | return | |
| 125 | end if | ||
| 126 | end select | ||
| 127 | class default | ||
| 128 | − | if(present(error))then | |
| 129 | − | error = -1 | |
| 130 | − | return | |
| 131 | else | ||
| 132 | − | write(err_msg,'("Unknown input type given for initialiser setup")') | |
| 133 | − | call stop_program(trim(err_msg)) | |
| 134 | − | return | |
| 135 | end if | ||
| 136 | end select | ||
| 137 | |||
| 138 | − | end function initialiser_setup | |
| 139 | !############################################################################### | ||
| 140 | |||
| 141 | − | end module athena__initialiser | |
| 142 |