| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__msgpass_layer | ||
| 2 | !! Module containing the types and interfaces of a message passing layer | ||
| 3 | use coreutils, only: real32 | ||
| 4 | use graphstruc, only: graph_type | ||
| 5 | use athena__base_layer, only: learnable_layer_type | ||
| 6 | use athena__clipper, only: clip_type | ||
| 7 | use diffstruc, only: array_type | ||
| 8 | implicit none | ||
| 9 | |||
| 10 | |||
| 11 | private | ||
| 12 | |||
| 13 | public :: msgpass_layer_type | ||
| 14 | |||
| 15 | |||
| 16 | !------------------------------------------------------------------------------- | ||
| 17 | ! Message passing layer | ||
| 18 | !------------------------------------------------------------------------------- | ||
| 19 | type, abstract, extends(learnable_layer_type) :: msgpass_layer_type | ||
| 20 | !! Type for message passing layer with overloaded procedures | ||
| 21 | !! | ||
| 22 | !! This derived type contains the implementation of a message passing | ||
| 23 | !! layer. These are useful for graph neural networks and other models | ||
| 24 | !! that require message passing. | ||
| 25 | !! For graphs, the literature uses the following terms, with vertex | ||
| 26 | !! and node often used interchangeably: | ||
| 27 | !! - vertex/node - the individual elements in the graph | ||
| 28 | !! - edge - the connections between the nodes | ||
| 29 | !! Here, we use the term vertex to refer to the individual elements | ||
| 30 | !! in the graph and edge to refer to the connections between vertices. | ||
| 31 | integer, dimension(:), allocatable :: num_vertex_features | ||
| 32 | !! Number of vertex features for each time step | ||
| 33 | integer, dimension(:), allocatable :: num_edge_features | ||
| 34 | !! Number of edge features for each time step | ||
| 35 | integer :: num_time_steps | ||
| 36 | !! Number of time steps | ||
| 37 | integer :: num_output_vertex_features | ||
| 38 | !! Number of output vertex features | ||
| 39 | integer :: num_output_edge_features | ||
| 40 | !! Number of output edge features | ||
| 41 | integer :: num_outputs | ||
| 42 | !! Number of outputs (if output is not graph structure) | ||
| 43 | |||
| 44 | integer, dimension(:), allocatable :: num_params_msg | ||
| 45 | !! Number of learnable parameters for each message | ||
| 46 | integer :: num_params_readout | ||
| 47 | !! Number of learnable parameters for the readout | ||
| 48 | |||
| 49 | contains | ||
| 50 | ! procedure, pass(this) :: set_hyperparams => set_hyperparams_msgpass | ||
| 51 | ! !! Set the hyperparameters for message passing layer | ||
| 52 | procedure, pass(this) :: init => init_msgpass | ||
| 53 | !! Initialise message passing layer | ||
| 54 | ! procedure, pass(this) :: print => print_msgpass | ||
| 55 | ! !! Print the message passing layer | ||
| 56 | ! procedure, pass(this) :: read => read_msgpass | ||
| 57 | ! !! Read the message passing layer | ||
| 58 | procedure, pass(this) :: set_graph => set_graph_msgpass | ||
| 59 | |||
| 60 | |||
| 61 | |||
| 62 | ! procedure, pass(this) :: reduce => layer_reduction | ||
| 63 | ! !! Reduce message passing layer | ||
| 64 | ! procedure, pass(this) :: merge => layer_merge | ||
| 65 | ! !! Merge message passing layer | ||
| 66 | procedure, pass(this) :: get_num_params => get_num_params_msgpass | ||
| 67 | !! Get the number of learnable parameters for message passing layer | ||
| 68 | |||
| 69 | procedure, pass(this) :: forward => forward_msgpass | ||
| 70 | !! Forward pass for message passing layer | ||
| 71 | |||
| 72 | procedure(update_message_msgpass), deferred, pass(this) :: update_message | ||
| 73 | !! Update the message | ||
| 74 | procedure(update_readout_msgpass), deferred, pass(this) :: update_readout | ||
| 75 | !! Update the readout | ||
| 76 | end type msgpass_layer_type | ||
| 77 | |||
| 78 | ! Interface for setting up the MPNN layer | ||
| 79 | !----------------------------------------------------------------------------- | ||
| 80 | interface msgpass_layer_type | ||
| 81 | !! Interface for setting up the MPNN layer | ||
| 82 | module function layer_setup( & | ||
| 83 | num_features, num_time_steps, & | ||
| 84 | verbose & | ||
| 85 | ) result(layer) | ||
| 86 | !! Set up the MPNN layer | ||
| 87 | !!! MAKE THESE ASSUMED RANK | ||
| 88 | integer, dimension(2), intent(in) :: num_features | ||
| 89 | !! Number of features | ||
| 90 | integer, intent(in) :: num_time_steps | ||
| 91 | !! Number of time steps | ||
| 92 | integer, optional, intent(in) :: verbose | ||
| 93 | !! Verbosity level | ||
| 94 | class(msgpass_layer_type), allocatable :: layer | ||
| 95 | !! Instance of the message passing layer | ||
| 96 | end function layer_setup | ||
| 97 | end interface msgpass_layer_type | ||
| 98 | |||
| 99 | ! Interface for handling the message passing layer parameters | ||
| 100 | !----------------------------------------------------------------------------- | ||
| 101 | interface | ||
| 102 | !! Interfaces for querying learnable parameters and setting the input graph | ||
| 103 | pure module function get_num_params_msgpass(this) result(num_params) | ||
| 104 | !! Get the number of learnable parameters for the message passing layer | ||
| 105 | class(msgpass_layer_type), intent(in) :: this | ||
| 106 | !! Instance of the message passing layer | ||
| 107 | integer :: num_params | ||
| 108 | !! Number of learnable parameters | ||
| 109 | end function get_num_params_msgpass | ||
| 110 | |||
| 111 | module subroutine set_graph_msgpass(this, graph) | ||
| 112 | !! Set the graph structure of the input data | ||
| 113 | class(msgpass_layer_type), intent(inout) :: this | ||
| 114 | !! Instance of the layer | ||
| 115 | type(graph_type), dimension(:), intent(in) :: graph | ||
| 116 | !! Graph structure of input data | ||
| 117 | end subroutine set_graph_msgpass | ||
| 118 | end interface | ||
| 119 | |||
| 120 | ! ! Interface for reducing and merging layers | ||
| 121 | ! !----------------------------------------------------------------------------- | ||
| 122 | ! interface | ||
| 123 | ! !! Interfaces for reducing and merging layers | ||
| 124 | ! module subroutine layer_reduction(this, rhs) | ||
| 125 | ! !! Reduce the layer | ||
| 126 | ! class(msgpass_layer_type), intent(inout) :: this | ||
| 127 | ! !! Instance of the message passing layer | ||
| 128 | ! class(learnable_layer_type), intent(in) :: rhs | ||
| 129 | ! !! Instance of the learnable layer (expects a message passing layer) | ||
| 130 | ! end subroutine layer_reduction | ||
| 131 | |||
| 132 | ! module subroutine layer_merge(this, input) | ||
| 133 | ! !! Merge the layer | ||
| 134 | ! class(msgpass_layer_type), intent(inout) :: this | ||
| 135 | ! !! Instance of the message passing layer | ||
| 136 | ! class(learnable_layer_type), intent(in) :: input | ||
| 137 | ! !! Instance of the learnable layer (expects a message passing layer) | ||
| 138 | ! end subroutine layer_merge | ||
| 139 | ! end interface | ||
| 140 | |||
| 141 | ! Interface for handling forward and backward passes | ||
| 142 | !----------------------------------------------------------------------------- | ||
| 143 | interface | ||
| 144 | module subroutine forward_msgpass(this, input) | ||
| 145 | !! Forward pass for the message passing layer | ||
| 146 | class(msgpass_layer_type), intent(inout) :: this | ||
| 147 | !! Instance of the layer type | ||
| 148 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 149 | !! Input data (i.e. vertex and edge features) | ||
| 150 | end subroutine forward_msgpass | ||
| 151 | end interface | ||
| 152 | |||
| 153 | ! Interface for handling graphs and outputs | ||
| 154 | !----------------------------------------------------------------------------- | ||
| 155 | interface | ||
| 156 | !! Interfaces for handling graphs and outputs, and initialising the layer | ||
| 157 | ! module subroutine print_msgpass(this, file) | ||
| 158 | ! !! Print the message passing layer | ||
| 159 | ! class(msgpass_layer_type), intent(in) :: this | ||
| 160 | ! !! Instance of the message passing layer | ||
| 161 | ! character(*), intent(in) :: file | ||
| 162 | ! !! File to print to | ||
| 163 | ! end subroutine print_msgpass | ||
| 164 | ! module subroutine read_msgpass(this, unit, verbose) | ||
| 165 | ! !! Read the message passing layer | ||
| 166 | ! class(msgpass_layer_type), intent(inout) :: this | ||
| 167 | ! !! Instance of the message passing layer | ||
| 168 | ! integer, intent(in) :: unit | ||
| 169 | ! !! Unit to read from | ||
| 170 | ! integer, optional, intent(in) :: verbose | ||
| 171 | ! !! Verbosity level | ||
| 172 | ! end subroutine read_msgpass | ||
| 173 | module subroutine init_msgpass(this, input_shape, verbose) | ||
| 174 | !! Initialise the message passing layer | ||
| 175 | class(msgpass_layer_type), intent(inout) :: this | ||
| 176 | !! Instance of the message passing layer | ||
| 177 | integer, dimension(:), intent(in) :: input_shape | ||
| 178 | !! Input shape | ||
| 179 | integer, optional, intent(in) :: verbose | ||
| 180 | !! Verbosity level | ||
| 181 | end subroutine init_msgpass | ||
| 182 | ! module subroutine set_hyperparams_msgpass( & | ||
| 183 | ! this, num_features, num_time_steps, num_outputs, verbose & | ||
| 184 | ! ) | ||
| 185 | ! !! Set the hyperparameters for the message passing layer | ||
| 186 | ! class(msgpass_layer_type), intent(inout) :: this | ||
| 187 | ! !! Instance of the message passing layer | ||
| 188 | ! integer, dimension(2), intent(in) :: num_features | ||
| 189 | ! !! Number of features | ||
| 190 | ! integer, intent(in) :: num_time_steps | ||
| 191 | ! !! Number of time steps | ||
| 192 | ! integer, intent(in) :: num_outputs | ||
| 193 | ! !! Number of outputs | ||
| 194 | ! integer, optional, intent(in) :: verbose | ||
| 195 | ! !! Verbosity level | ||
| 196 | ! end subroutine set_hyperparams_msgpass | ||
| 197 | end interface | ||
| 198 | !------------------------------------------------------------------------------- | ||
| 199 | |||
| 200 | |||
| 201 | !------------------------------------------------------------------------------! | ||
| 202 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 203 | !------------------------------------------------------------------------------! | ||
| 204 | |||
| 205 | |||
| 206 | |||
| 207 | interface | ||
| 208 | !! interface for the message forward and backward passes | ||
| 209 | module subroutine update_message_msgpass(this, input) | ||
| 210 | !! Update the message | ||
| 211 | class(msgpass_layer_type), intent(inout), target :: this | ||
| 212 | !! Instance of the message passing layer | ||
| 213 | class(array_type), dimension(:,:), intent(in), target :: input | ||
| 214 | !! Input data (i.e. vertex and edge features) | ||
| 215 | end subroutine update_message_msgpass | ||
| 216 | end interface | ||
| 217 | |||
| 218 | interface | ||
| 219 | !! interface for the readout forward and backward passes | ||
| 220 | module subroutine update_readout_msgpass(this) | ||
| 221 | !! Update the readout | ||
| 222 | class(msgpass_layer_type), intent(inout), target :: this | ||
| 223 | !! Instance of the message passing layer | ||
| 224 | end subroutine update_readout_msgpass | ||
| 225 | end interface | ||
| 226 | |||
| 227 | |||
| 228 | |||
| 229 | − | end module athena__msgpass_layer | |
| 230 |