| Line | Branch | Exec | Source |
|---|---|---|---|
module athena__initialiser_ident
  !! Module containing the implementation of the identity initialiser
  !!
  !! This module contains the implementation of the identity initialiser
  !! for the weights and biases of a layer
  use coreutils, only: real32, stop_program
  use athena__misc_types, only: base_init_type
  implicit none


  private

  public :: ident_init_type


  type, extends(base_init_type) :: ident_init_type
     !! Type for the identity initialiser
   contains
     procedure, pass(this) :: initialise => ident_initialise
     !! Initialise the weights and biases using the identity matrix
  end type ident_init_type


  interface ident_init_type
     module function initialiser_ident_setup() result(initialiser)
       !! Interface for the Identity initialiser
       type(ident_init_type) :: initialiser
       !! Identity initialiser object
     end function initialiser_ident_setup
  end interface ident_init_type



contains

!###############################################################################
  module function initialiser_ident_setup() result(initialiser)
    !! Construct an identity initialiser whose name is set to "ident"
    implicit none

    type(ident_init_type) :: initialiser
    !! Identity initialiser object

    initialiser%name = "ident"

  end function initialiser_ident_setup
!###############################################################################


!###############################################################################
  subroutine ident_initialise(this, input, fan_in, fan_out, spacing)
    !! Initialise the weights and biases using the identity matrix
    !!
    !! Behaviour by rank of `input`:
    !!  - rank 0: the scalar is set to 1.
    !!  - rank 1: a length-1 vector is set to 1; a longer vector is treated
    !!    as a flattened matrix whose layout is given by `spacing`, which
    !!    must then be present.
    !!  - rank 2 to 6: all extents must be equal; the tensor is zeroed and
    !!    the elements (i,i,...,i) are set to 1.
    !! Non-square tensors and unsupported ranks are reported through
    !! stop_program.
    implicit none

    ! Arguments
    class(ident_init_type), intent(inout) :: this
    !! Instance of the identity initialiser
    real(real32), dimension(..), intent(out) :: input
    !! Weights and biases to initialise
    integer, optional, intent(in) :: fan_in, fan_out
    !! Number of input and output parameters (not used by this initialiser)
    integer, dimension(:), optional, intent(in) :: spacing
    !! Spacing of the input and output units

    ! Local variables
    integer :: i, j
    !! Loop indices
    integer :: ndim
    !! Number of entries in spacing
    integer, dimension(:), allocatable :: iprime, iprime2
    !! Index variables


    ! A tensor is square when every extent equals the first one, so a single
    ! differing extent is enough to reject it (any, not all: shape(1) always
    ! equals size(input,1), which made an all() test never fire).
    ! Scalars and vectors are skipped: shape() of a scalar is zero-sized and
    ! size(input,1) is not valid for a rank-0 argument.
    if(rank(input) .gt. 1)then
       if(any(shape(input) .ne. size(input,1)))then
          call stop_program( &
               'A non-square tensor cannot be initialised as an identity matrix' &
          )
          return
       end if
    end if

    select rank(input)
    rank(0)
       input = 1._real32
    rank(1)
       if(size(input).ne.1)then
          if(.not.present(spacing))then
             call stop_program( &
                  'A vector of length greater than 1 cannot be &
                  &initialised as an identity matrix' &
             )
             return
          end if

          ! intent(out) leaves the vector undefined on entry; clear it so
          ! that only the entries set below are non-zero.
          input = 0._real32
          ndim = size(spacing)
          if(ndim.eq.1)then
             ! Diagonal of a column-major matrix with leading dimension
             ! spacing(1). Stopping at the shorter side keeps the flat index
             ! 1 + (i-1)*(spacing(1)+1) within size(input); for a square
             ! layout (size(input) == spacing(1)**2) both bounds coincide.
             do i = 1, min(size(input)/spacing(1), spacing(1))
                input(1 + ( i - 1 ) * ( spacing(1) + 1) ) = 1._real32
             end do
          elseif(ndim.gt.1)then
             ! NOTE(review): multi-dimensional spacing layout; the index
             ! arithmetic below is not bounds-checked against size(input) -
             ! confirm the intended layout with the callers.
             allocate(iprime(ndim))
             allocate(iprime2(ndim))
             iprime2 = 0
             iprime2(1) = 1
             do i = 1, size(input)/spacing(1)
                iprime(ndim) = &
                     mod( &
                     (i - 1) / product( spacing(:ndim-1) ), &
                     product(spacing(:ndim)) &
                     )
                iprime(ndim) = iprime(ndim) * product(spacing(:ndim-1))
                do j = ndim - 1, 1, -1
                   if(sum(iprime(j+1:)).eq.0) then
                      iprime(j) = 0
                   else
                      iprime(j) = &
                           mod( &
                           (i - 1), &
                           sum(iprime(j+1:)) &
                           ) / product(spacing(:j-1))
                   end if
                   iprime(j) = iprime(j) * product(spacing(:j-1))
                end do
                input(1 + sum(iprime * ( spacing(1) + iprime2 ))) = 1._real32
             end do
          end if
       else
          input = 1._real32
       end if
    rank(2)
       input = 0._real32
       do i = 1, size(input,1)
          input(i,i) = 1._real32
       end do
    rank(3)
       input = 0._real32
       do i = 1, size(input,1)
          input(i,i,i) = 1._real32
       end do
    rank(4)
       input = 0._real32
       do i = 1, size(input,1)
          input(i,i,i,i) = 1._real32
       end do
    rank(5)
       input = 0._real32
       do i = 1, size(input,1)
          input(i,i,i,i,i) = 1._real32
       end do
    rank(6)
       input = 0._real32
       do i = 1, size(input,1)
          input(i,i,i,i,i,i) = 1._real32
       end do
    rank default
       ! Without this branch higher ranks would leave input undefined.
       call stop_program( &
            'Identity initialiser only supports tensors of rank 0 to 6' &
       )
    end select

  end subroutine ident_initialise
!###############################################################################

end module athena__initialiser_ident
| 160 |