| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | submodule(athena__msgpass_layer) athena__msgpass_layer_submodule |
| 2 | | | !! Submodule containing implementations for a message passing layer |
| 3 | | | implicit none |
| 4 | | | |
| 5 | | | |
| 6 | | | |
| 7 | | | contains |
| 8 | | | |
| 9 | | | !############################################################################### |
| 10 | | − | pure module function get_num_params_msgpass(this) result(num_params) |
| 11 | | | !! Get the number of learnable parameters in the layer |
| 12 | | | implicit none |
| 13 | | | |
| 14 | | | ! Arguments |
| 15 | | | class(msgpass_layer_type), intent(in) :: this |
| 16 | | | !! Instance of the layer type |
| 17 | | | integer :: num_params |
| 18 | | | !! Number of learnable parameters |
| 19 | | | |
| 20 | | | ! Local variables |
| 21 | | | integer :: t |
| 22 | | | !! Time step |
| 23 | | | |
| 24 | | − | num_params = sum(this%num_params_msg) + this%num_params_readout |
| 25 | | − | end function get_num_params_msgpass |
| 26 | | | !############################################################################### |
| 27 | | | |
| 28 | | | |
| 29 | | | !##############################################################################! |
| 30 | | | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! |
| 31 | | | !##############################################################################! |
| 32 | | | |
| 33 | | | |
| 34 | | | !############################################################################### |
| 35 | | − | module function layer_setup( & |
| 36 | | | num_features, num_time_steps, & |
| 37 | | | verbose & |
| 38 | | − | ) result(layer) |
| 39 | | | !! Procedure to set up the layer |
| 40 | | | implicit none |
| 41 | | | |
| 42 | | | ! Arguments |
| 43 | | | integer, dimension(2), intent(in) :: num_features |
| 44 | | | !! Number of features |
| 45 | | | integer, intent(in) :: num_time_steps |
| 46 | | | !! Number of time steps |
| 47 | | | integer, optional, intent(in) :: verbose |
| 48 | | | !! Verbosity level |
| 49 | | | |
| 50 | | | class(msgpass_layer_type), allocatable :: layer |
| 51 | | | !! Instance of the layer type |
| 52 | | | |
| 53 | | | ! Local variables |
| 54 | | | integer :: verbose_ = 0 |
| 55 | | | !! Verbosity level |
| 56 | | | |
| 57 | | | |
| 58 | | − | end function layer_setup |
| 59 | | | !############################################################################### |
| 60 | | | |
| 61 | | | |
| 62 | | | !############################################################################### |
| 63 | | | ! module subroutine set_hyperparams_msgpass( & |
| 64 | | | ! this, & |
| 65 | | | ! num_features, num_time_steps, num_outputs, verbose & |
| 66 | | | ! ) |
| 67 | | | ! !! Set the hyperparameters for the layer |
| 68 | | | ! implicit none |
| 69 | | | |
| 70 | | | ! ! Arguments |
| 71 | | | ! class(msgpass_layer_type), intent(inout) :: this |
| 72 | | | ! !! Instance of the layer type |
| 73 | | | ! integer, dimension(2), intent(in) :: num_features |
| 74 | | | ! !! Number of features |
| 75 | | | ! integer, intent(in) :: num_time_steps |
| 76 | | | ! !! Number of time steps |
| 77 | | | ! integer, intent(in) :: num_outputs |
| 78 | | | ! !! Number of output features |
| 79 | | | ! integer, optional, intent(in) :: verbose |
| 80 | | | ! !! Verbosity level |
| 81 | | | |
| 82 | | | |
| 83 | | | ! this%name = 'msgpass' |
| 84 | | | ! this%type = 'msgp' |
| 85 | | | ! this%input_rank = 1 |
| 86 | | | ! this%num_outputs = num_outputs |
| 87 | | | ! this%num_time_steps = num_time_steps |
| 88 | | | ! this%num_vertex_features = num_features(1) |
| 89 | | | ! this%num_edge_features = num_features(2) |
| 90 | | | |
| 91 | | | ! end subroutine set_hyperparams_msgpass |
| 92 | | | !############################################################################### |
| 93 | | | |
| 94 | | | |
| 95 | | | !############################################################################### |
| 96 | | − | module subroutine init_msgpass(this, input_shape, verbose) |
| 97 | | | !! Initialise the layer |
| 98 | | | implicit none |
| 99 | | | |
| 100 | | | ! Arguments |
| 101 | | | class(msgpass_layer_type), intent(inout) :: this |
| 102 | | | !! Instance of the layer type |
| 103 | | | integer, dimension(:), intent(in) :: input_shape |
| 104 | | | !! Input shape |
| 105 | | | integer, optional, intent(in) :: verbose |
| 106 | | | !! Verbosity level |
| 107 | | | |
| 108 | | | ! Local variables |
| 109 | | | integer :: t |
| 110 | | | !! Time step |
| 111 | | | integer :: verbose_ = 0 |
| 112 | | | !! Verbosity level |
| 113 | | | integer :: num_params_message |
| 114 | | | !! Number of parameters in the message |
| 115 | | | integer :: num_params_readout |
| 116 | | | !! Number of parameters in the readout |
| 117 | | | |
| 118 | | | |
| 119 | | | !--------------------------------------------------------------------------- |
| 120 | | | ! Initialise optional arguments |
| 121 | | | !--------------------------------------------------------------------------- |
| 122 | | − | if (present(verbose)) verbose_ = verbose |
| 123 | | | |
| 124 | | | |
| 125 | | | !--------------------------------------------------------------------------- |
| 126 | | | ! Initialise number of inputs |
| 127 | | | !--------------------------------------------------------------------------- |
| 128 | | − | this%input_shape = [ 1 ] !input_shape |
| 129 | | | ! this%output_shape = this%num_outputs |
| 130 | | | !if(.not.allocated(this%input_shape)) call this%set_shape(input_shape) |
| 131 | | − | this%num_params = this%get_num_params() |
| 132 | | | |
| 133 | | | |
| 134 | | | !--------------------------------------------------------------------------- |
| 135 | | | ! Allocate arrays |
| 136 | | | !--------------------------------------------------------------------------- |
| 137 | | − | if(allocated(this%output)) deallocate(this%output) |
| 138 | | | |
| 139 | | − | end subroutine init_msgpass |
| 140 | | | !############################################################################### |
| 141 | | | |
| 142 | | | |
| 143 | | | !############################################################################### |
| 144 | | − | module subroutine set_graph_msgpass(this, graph) |
| 145 | | | !! Set the graph structure of the input data |
| 146 | | | implicit none |
| 147 | | | |
| 148 | | | ! Arguments |
| 149 | | | class(msgpass_layer_type), intent(inout) :: this |
| 150 | | | !! Instance of the layer |
| 151 | | | type(graph_type), dimension(:), intent(in) :: graph |
| 152 | | | !! Graph structure of input data |
| 153 | | | |
| 154 | | | ! Local variables |
| 155 | | | integer :: s |
| 156 | | | !! Loop indices |
| 157 | | | |
| 158 | | − | if(allocated(this%graph))then |
| 159 | | − | if(size(this%graph).ne.size(graph))then |
| 160 | | − | deallocate(this%graph) |
| 161 | | − | allocate(this%graph(size(graph))) |
| 162 | | | end if |
| 163 | | | else |
| 164 | | − | allocate(this%graph(size(graph))) |
| 165 | | | end if |
| 166 | | − | do s = 1, size(graph) |
| 167 | | − | this%graph(s)%adj_ia = graph(s)%adj_ia |
| 168 | | − | this%graph(s)%adj_ja = graph(s)%adj_ja |
| 169 | | − | this%graph(s)%edge_weights = graph(s)%edge_weights |
| 170 | | − | this%graph(s)%num_edges = graph(s)%num_edges |
| 171 | | − | this%graph(s)%num_vertices = graph(s)%num_vertices |
| 172 | | | end do |
| 173 | | | |
| 174 | | − | end subroutine set_graph_msgpass |
| 175 | | | !############################################################################### |
| 176 | | | |
| 177 | | | |
| 178 | | | !##############################################################################! |
| 179 | | | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! |
| 180 | | | !##############################################################################! |
| 181 | | | |
| 182 | | | |
| 183 | | | !############################################################################### |
| 184 | | − | module subroutine forward_msgpass(this, input) |
| 185 | | | !! Forward propagation for the layer |
| 186 | | | implicit none |
| 187 | | | |
| 188 | | | ! Arguments |
| 189 | | | class(msgpass_layer_type), intent(inout) :: this |
| 190 | | | !! Instance of the layer type |
| 191 | | | class(array_type), dimension(:,:), intent(in) :: input |
| 192 | | | !! Input data (i.e. vertex and edge features) |
| 193 | | | |
| 194 | | | |
| 195 | | − | call this%update_message(input) |
| 196 | | − | call this%update_readout() |
| 197 | | | |
| 198 | | − | end subroutine forward_msgpass |
| 199 | | | !############################################################################### |
| 200 | | | |
| 201 | | | end submodule athena__msgpass_layer_submodule |
| 202 | | | |