GCC Code Coverage Report


Directory: src/athena/
File: athena_msgpass_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__msgpass_layer
2 !! Module containing the abstract type and interfaces of a message passing layer
3 use coreutils, only: real32
4 use graphstruc, only: graph_type
5 use athena__base_layer, only: learnable_layer_type
6 use athena__clipper, only: clip_type
7 use diffstruc, only: array_type
8 implicit none
9 ! NOTE(review): real32 and clip_type are not referenced in this module body;
10 ! presumably they are used by submodules via host association - confirm.
11 private
12
13 public :: msgpass_layer_type
14
15
16 !-------------------------------------------------------------------------------
17 ! Message passing layer
18 !-------------------------------------------------------------------------------
19 type, abstract, extends(learnable_layer_type) :: msgpass_layer_type
20 !! Abstract type for a message passing layer with type-bound procedures
21 !!
22 !! This abstract derived type defines the components and procedure
23 !! interface shared by message passing layers, which are useful for
24 !! graph neural networks and other models that require message passing.
25 !! For graphs, the literature uses several terms seemingly
26 !! interchangeably for two concepts:
27 !! - vertex/node - the individual elements in the graph
28 !! - edge - the connections between the nodes
29 !! Here, we use the term vertex to refer to the individual elements
30 !! in the graph and edge to refer to the connections between vertices.
31 integer, dimension(:), allocatable :: num_vertex_features
32 !! Number of vertex features for each time step
33 integer, dimension(:), allocatable :: num_edge_features
34 !! Number of edge features for each time step
35 integer :: num_time_steps
36 !! Number of time steps
37 integer :: num_output_vertex_features
38 !! Number of output vertex features
39 integer :: num_output_edge_features
40 !! Number of output edge features
41 integer :: num_outputs
42 !! Number of outputs (used when the output is not a graph structure)
43
44 integer, dimension(:), allocatable :: num_params_msg
45 !! Number of learnable parameters for each message (presumably per time step - TODO confirm)
46 integer :: num_params_readout
47 !! Number of learnable parameters for the readout
48
49 contains
50 ! procedure, pass(this) :: set_hyperparams => set_hyperparams_msgpass
51 ! !! Set the hyperparameters for message passing layer
52 procedure, pass(this) :: init => init_msgpass
53 !! Initialise message passing layer
54 ! procedure, pass(this) :: print => print_msgpass
55 ! !! Print the message passing layer
56 ! procedure, pass(this) :: read => read_msgpass
57 ! !! Read the message passing layer
58 procedure, pass(this) :: set_graph => set_graph_msgpass
59 !! Set the graph structure of the input data
60
61
62 ! procedure, pass(this) :: reduce => layer_reduction
63 ! !! Reduce message passing layer
64 ! procedure, pass(this) :: merge => layer_merge
65 ! !! Merge message passing layer
66 procedure, pass(this) :: get_num_params => get_num_params_msgpass
67 !! Get the number of learnable parameters for message passing layer
68
69 procedure, pass(this) :: forward => forward_msgpass
70 !! Forward pass for message passing layer
71
72 procedure(update_message_msgpass), deferred, pass(this) :: update_message
73 !! Update the message (implemented by each extending type)
74 procedure(update_readout_msgpass), deferred, pass(this) :: update_readout
75 !! Update the readout (implemented by each extending type)
76 end type msgpass_layer_type
77
78 ! Generic interface (named after the type) for setting up the MPNN layer
79 !-----------------------------------------------------------------------------
80 interface msgpass_layer_type
81 !! Interface for setting up the MPNN layer
82 module function layer_setup( &
83 num_features, num_time_steps, &
84 verbose &
85 ) result(layer)
86 !! Set up the MPNN layer
87 !! TODO: make these dummy arguments assumed rank
88 integer, dimension(2), intent(in) :: num_features
89 !! Number of features (presumably [vertex, edge] - TODO confirm)
90 integer, intent(in) :: num_time_steps
91 !! Number of time steps
92 integer, optional, intent(in) :: verbose
93 !! Verbosity level
94 class(msgpass_layer_type), allocatable :: layer
95 !! Instance of the message passing layer
96 end function layer_setup
97 end interface msgpass_layer_type
98
99 ! Interfaces for parameter counting and graph assignment
100 !-----------------------------------------------------------------------------
101 interface
102 !! Interfaces for counting learnable parameters and setting the graph
103 pure module function get_num_params_msgpass(this) result(num_params)
104 !! Get the number of learnable parameters for the message passing layer
105 class(msgpass_layer_type), intent(in) :: this
106 !! Instance of the message passing layer
107 integer :: num_params
108 !! Number of learnable parameters
109 end function get_num_params_msgpass
110
111 module subroutine set_graph_msgpass(this, graph)
112 !! Set the graph structure of the input data
113 class(msgpass_layer_type), intent(inout) :: this
114 !! Instance of the layer
115 type(graph_type), dimension(:), intent(in) :: graph
116 !! Graph structure of input data
117 end subroutine set_graph_msgpass
118 end interface
119
120 ! ! Interface for reducing and merging layers
121 ! !-----------------------------------------------------------------------------
122 ! interface
123 ! !! Interfaces for reducing and merging layers
124 ! module subroutine layer_reduction(this, rhs)
125 ! !! Reduce the layer
126 ! class(msgpass_layer_type), intent(inout) :: this
127 ! !! Instance of the message passing layer
128 ! class(learnable_layer_type), intent(in) :: rhs
129 ! !! Instance of the learnable layer (expects a message passing layer)
130 ! end subroutine layer_reduction
131
132 ! module subroutine layer_merge(this, input)
133 ! !! Merge the layer
134 ! class(msgpass_layer_type), intent(inout) :: this
135 ! !! Instance of the message passing layer
136 ! class(learnable_layer_type), intent(in) :: input
137 ! !! Instance of the learnable layer (expects a message passing layer)
138 ! end subroutine layer_merge
139 ! end interface
140
141 ! Interface for the forward pass
142 !-----------------------------------------------------------------------------
143 interface
144 module subroutine forward_msgpass(this, input)
145 !! Forward pass for the message passing layer
146 class(msgpass_layer_type), intent(inout) :: this
147 !! Instance of the layer type
148 class(array_type), dimension(:,:), intent(in) :: input
149 !! Input data (i.e. vertex and edge features)
150 end subroutine forward_msgpass
151 end interface
152
153 ! Interface for initialising the layer
154 !-----------------------------------------------------------------------------
155 interface
156 !! Interface for initialising the layer (print/read/set_hyperparams are disabled)
157 ! module subroutine print_msgpass(this, file)
158 ! !! Print the message passing layer
159 ! class(msgpass_layer_type), intent(in) :: this
160 ! !! Instance of the message passing layer
161 ! character(*), intent(in) :: file
162 ! !! File to print to
163 ! end subroutine print_msgpass
164 ! module subroutine read_msgpass(this, unit, verbose)
165 ! !! Read the message passing layer
166 ! class(msgpass_layer_type), intent(inout) :: this
167 ! !! Instance of the message passing layer
168 ! integer, intent(in) :: unit
169 ! !! Unit to read from
170 ! integer, optional, intent(in) :: verbose
171 ! !! Verbosity level
172 ! end subroutine read_msgpass
173 module subroutine init_msgpass(this, input_shape, verbose)
174 !! Initialise the message passing layer
175 class(msgpass_layer_type), intent(inout) :: this
176 !! Instance of the message passing layer
177 integer, dimension(:), intent(in) :: input_shape
178 !! Input shape
179 integer, optional, intent(in) :: verbose
180 !! Verbosity level
181 end subroutine init_msgpass
182 ! module subroutine set_hyperparams_msgpass( &
183 ! this, num_features, num_time_steps, num_outputs, verbose &
184 ! )
185 ! !! Set the hyperparameters for the message passing layer
186 ! class(msgpass_layer_type), intent(inout) :: this
187 ! !! Instance of the message passing layer
188 ! integer, dimension(2), intent(in) :: num_features
189 ! !! Number of features
190 ! integer, intent(in) :: num_time_steps
191 ! !! Number of time steps
192 ! integer, intent(in) :: num_outputs
193 ! !! Number of outputs
194 ! integer, optional, intent(in) :: verbose
195 ! !! Verbosity level
196 ! end subroutine set_hyperparams_msgpass
197 end interface
198 !-------------------------------------------------------------------------------
199
200
201 !------------------------------------------------------------------------------!
202 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
203 !------------------------------------------------------------------------------!
204 ! NOTE(review): the deferred update_message/update_readout bindings take their
205 ! interfaces from the separate module procedure declarations below; an
206 ! 'abstract interface' may be what is intended here - confirm.
207 interface
208 !! Interface for the message update (used by the deferred update_message)
209 module subroutine update_message_msgpass(this, input)
210 !! Update the message
211 class(msgpass_layer_type), intent(inout), target :: this
212 !! Instance of the message passing layer
213 class(array_type), dimension(:,:), intent(in), target :: input
214 !! Input data (i.e. vertex and edge features)
215 end subroutine update_message_msgpass
216 end interface
217
218 interface
219 !! Interface for the readout update (used by the deferred update_readout)
220 module subroutine update_readout_msgpass(this)
221 !! Update the readout
222 class(msgpass_layer_type), intent(inout), target :: this
223 !! Instance of the message passing layer
224 end subroutine update_readout_msgpass
225 end interface
226
227
228
229 end module athena__msgpass_layer
230