GCC Code Coverage Report


Directory: src/athena/
File: athena_duvenaud_msgpass_layer.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__duvenaud_msgpass_layer
2 !! Module implementing Duvenaud message passing for molecular graphs
3 !!
4 !! This module implements the graph neural network architecture from
5 !! Duvenaud et al. (2015) for learning on molecular graphs with both
6 !! vertex (node) and edge features.
7 !!
8 !! Mathematical operation (per time step t):
9 !! \[ h_v^{(t+1)} = \sigma\left( h_v^{(t)} + \sum_{u \in \mathcal{N}(v)} M(h_v^{(t)}, h_u^{(t)}, e_{vu}) \right) \]
10 !!
11 !! Graph readout (aggregation to fixed-size vector):
12 !! \[ h_{\text{graph}} = \sigma_{\text{readout}}\left( \sum_{d=1}^D \sum_{v:\deg(v)=d} W_d h_v^{(T)} \right) \]
13 !!
14 !! where \( M \) is a learned message function, \( \sigma \) is activation function,
15 !! \( \mathcal{N}(v) \) are neighbors of \( v \), \( e_{vu} \) are edge features, \( W_d \) are
16 !! degree-specific weight matrices, and \( D \) is max vertex degree.
17 !!
18 !! Reference: Duvenaud et al. (2015), NeurIPS
19 use coreutils, only: real32
20 use graphstruc, only: graph_type
21 use athena__misc_types, only: base_actv_type, base_init_type
22 use diffstruc, only: array_type, sum, matmul, operator(+)
23 use athena__base_layer, only: base_layer_type
24 use athena__msgpass_layer, only: msgpass_layer_type
25 use athena__diffstruc_extd, only: duvenaud_propagate, duvenaud_update
26 implicit none
27
28
29 private
30
31 public :: duvenaud_msgpass_layer_type
32 public :: read_duvenaud_msgpass_layer
33
34
35 !-------------------------------------------------------------------------------
36 ! Message passing layer
37 !-------------------------------------------------------------------------------
type, extends(msgpass_layer_type) :: duvenaud_msgpass_layer_type
   !! Duvenaud et al. (2015) message passing layer with degree-specific
   !! message kernels and a summed graph-level readout

   integer :: min_vertex_degree = 1
   !! Smallest vertex degree given its own slice of the message kernels
   integer :: max_vertex_degree = 0
   !! Largest vertex degree given its own slice of the message kernels

   class(base_actv_type), allocatable :: activation_readout
   !! Activation applied to every readout contribution
   type(array_type), dimension(:,:), allocatable :: z
   !! Vertex features produced by each step, indexed (time step, sample)
   type(array_type), dimension(:,:), allocatable :: z_readout
   !! Readout workspace (not referenced by the procedures of this module)

 contains
   procedure, pass(this) :: get_num_params => get_num_params_duvenaud
   !! Count the learnable parameters of the layer
   procedure, pass(this) :: set_hyperparams => set_hyperparams_duvenaud
   !! Store the layer hyperparameters
   procedure, pass(this) :: init => init_duvenaud
   !! Allocate and initialise the layer parameters
   procedure, pass(this) :: print_to_unit => print_to_unit_duvenaud
   !! Write the layer description and weights to a unit
   procedure, pass(this) :: read => read_duvenaud
   !! Read the layer from a unit

   procedure, pass(this) :: set_graph => set_graph_duvenaud
   !! Attach the graph structure of the current batch

   procedure, pass(this) :: update_message => update_message_duvenaud
   !! Run the message passing time steps

   procedure, pass(this) :: update_readout => update_readout_duvenaud
   !! Build the graph-level readout from the per-step features

   final :: finalise_duvenaud
   !! Release allocatable shape/output storage
end type duvenaud_msgpass_layer_type
74
75 ! Interface for setting up the MPNN layer
76 !-----------------------------------------------------------------------------
interface duvenaud_msgpass_layer_type
   !! Constructor-style generic used to build a Duvenaud message passing layer
   module function layer_setup( &
        num_vertex_features, num_edge_features, num_time_steps, &
        max_vertex_degree, num_outputs, min_vertex_degree, &
        message_activation, readout_activation, &
        kernel_initialiser, verbose &
   ) result(layer)
     !! Build and initialise a Duvenaud message passing layer
     integer, dimension(:), intent(in) :: num_vertex_features, &
          num_edge_features
     !! Vertex / edge feature counts: a single value, or one per step 0..T
     integer, intent(in) :: num_time_steps, max_vertex_degree, num_outputs
     !! Number of time steps, maximum vertex degree, number of outputs
     integer, intent(in), optional :: min_vertex_degree
     !! Minimum vertex degree (defaults to 1)
     class(*), intent(in), optional :: message_activation, readout_activation
     !! Activation specifications for the message and readout stages
     character(*), intent(in), optional :: kernel_initialiser
     !! Name of the kernel initialiser
     integer, intent(in), optional :: verbose
     !! Verbosity level
     type(duvenaud_msgpass_layer_type) :: layer
     !! Newly constructed layer
   end function layer_setup
end interface duvenaud_msgpass_layer_type
113
character(*), parameter :: default_message_actv_name = "sigmoid"
!! Message activation used when the caller does not supply one
character(*), parameter :: default_readout_actv_name = "softmax"
!! Readout activation used when the caller does not supply one
116
117
118
119 contains
120
121 !###############################################################################
subroutine finalise_duvenaud(this)
  !! Finaliser: release the shape and output storage held by the layer
  implicit none

  ! Arguments
  type(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Duvenaud message passing layer being finalised

  if(allocated(this%output)) deallocate(this%output)
  if(allocated(this%output_shape)) deallocate(this%output_shape)
  if(allocated(this%input_shape)) deallocate(this%input_shape)

end subroutine finalise_duvenaud
135 !###############################################################################
136
137
138 !##############################################################################!
139 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
140 !##############################################################################!
141
142
143 !###############################################################################
pure function get_num_params_duvenaud(this) result(num_params)
  !! Get the number of parameters for the message passing layer
  !!
  !! For every time step t = 1..T this counts
  !!  - the message kernel of shape
  !!    [ nv(t), nv(t-1) + ne(0), max_vertex_degree - min_vertex_degree + 1 ]
  !!  - the readout kernel of shape [ num_outputs, nv(t) ]
  !! which are the parameter arrays allocated by init_duvenaud. Here nv and ne
  !! are num_vertex_features and num_edge_features. The per-step vertex
  !! feature counts nv(t) may differ between steps.
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(in) :: this
  !! Instance of the message passing layer
  integer :: num_params
  !! Number of parameters

  ! Local variables
  integer :: t
  !! Time step index
  integer :: num_degrees
  !! Number of degree-specific slices in each message kernel

  num_degrees = this%max_vertex_degree - this%min_vertex_degree + 1

  num_params = 0
  do t = 1, this%num_time_steps
     num_params = num_params + &
          ( this%num_vertex_features(t-1) + this%num_edge_features(0) ) * &
          this%num_vertex_features(t) * num_degrees + &
          this%num_outputs * this%num_vertex_features(t)
  end do

end function get_num_params_duvenaud
165 !###############################################################################
166
167
168 !##############################################################################!
169 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
170 !##############################################################################!
171
172
173 !###############################################################################
module function layer_setup( &
     num_vertex_features, num_edge_features, num_time_steps, &
     max_vertex_degree, &
     num_outputs, &
     min_vertex_degree, &
     message_activation, &
     readout_activation, &
     kernel_initialiser, &
     verbose &
) result(layer)
  !! Set up the message passing layer
  !!
  !! Resolves the optional arguments (activations, minimum vertex degree,
  !! kernel initialiser, verbosity), stores them through set_hyperparams and
  !! then initialises the layer for an input shape of
  !! [ num_vertex_features(0), num_edge_features(0) ].
  !!
  !! If max_vertex_degree < min_vertex_degree an error is written to unit 0
  !! and the layer is returned without being configured.
  use athena__initialiser, only: initialiser_setup
  use athena__activation, only: activation_setup
  implicit none

  ! Arguments
  integer, dimension(:), intent(in) :: num_vertex_features
  !! Number of vertex features (a single value or one per step 0..T)
  integer, dimension(:), intent(in) :: num_edge_features
  !! Number of edge features (a single value or one per step 0..T)
  integer, intent(in) :: num_time_steps
  !! Number of time steps
  integer, intent(in) :: max_vertex_degree
  !! Maximum vertex degree
  integer, intent(in) :: num_outputs
  !! Number of outputs
  integer, optional, intent(in) :: min_vertex_degree
  !! Minimum vertex degree (default 1)
  class(*), optional, intent(in) :: message_activation, &
       readout_activation
  !! Message and readout activation functions
  character(*), optional, intent(in) :: kernel_initialiser
  !! Kernel initialiser
  integer, optional, intent(in) :: verbose
  !! Verbosity level
  type(duvenaud_msgpass_layer_type) :: layer
  !! Instance of the message passing layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level (0 unless supplied)
  class(base_actv_type), allocatable :: message_activation_, readout_activation_
  !! Resolved message and readout activation functions
  class(base_init_type), allocatable :: kernel_initialiser_
  !! Kernel initialiser (left unallocated to request the default)
  integer :: min_vertex_degree_
  !! Minimum vertex degree (1 unless supplied)


  ! Defaults are assigned at run time instead of in the declarations: an
  ! initialised local variable is implicitly SAVEd, so a value passed in one
  ! call would otherwise become the "default" for every later call.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose
  min_vertex_degree_ = 1
  if(present(min_vertex_degree)) min_vertex_degree_ = min_vertex_degree


  !---------------------------------------------------------------------------
  ! Set activation functions
  !---------------------------------------------------------------------------
  if(present(message_activation))then
     message_activation_ = activation_setup(message_activation)
  else
     message_activation_ = activation_setup(default_message_actv_name)
  end if
  if(present(readout_activation))then
     readout_activation_ = activation_setup(readout_activation)
  else
     readout_activation_ = activation_setup(default_readout_actv_name)
  end if


  !---------------------------------------------------------------------------
  ! Check the vertex degree range
  !---------------------------------------------------------------------------
  if(max_vertex_degree.lt.min_vertex_degree_)then
     write(0,*) "Error: max_vertex_degree < min_vertex_degree"
     return
  end if


  !---------------------------------------------------------------------------
  ! Define weights (kernels) initialiser
  !---------------------------------------------------------------------------
  if(present(kernel_initialiser))then
     kernel_initialiser_ = initialiser_setup(kernel_initialiser)
  end if


  !---------------------------------------------------------------------------
  ! Set hyperparameters
  !---------------------------------------------------------------------------
  call layer%set_hyperparams( &
       num_vertex_features = num_vertex_features, &
       num_edge_features = num_edge_features, &
       min_vertex_degree = min_vertex_degree_, &
       max_vertex_degree = max_vertex_degree, &
       num_time_steps = num_time_steps, &
       num_outputs = num_outputs, &
       message_activation = message_activation_, &
       readout_activation = readout_activation_, &
       kernel_initialiser = kernel_initialiser_, &
       verbose = verbose_ &
  )


  !---------------------------------------------------------------------------
  ! Initialise layer shape
  !---------------------------------------------------------------------------
  call layer%init(input_shape=[ &
       layer%num_vertex_features(0), &
       layer%num_edge_features(0) &
  ])

end function layer_setup
284 !###############################################################################
285
286
287 !###############################################################################
subroutine set_hyperparams_duvenaud( &
     this, &
     num_vertex_features, num_edge_features, &
     min_vertex_degree, &
     max_vertex_degree, &
     num_time_steps, &
     num_outputs, &
     message_activation, &
     readout_activation, &
     kernel_initialiser, &
     verbose &
)
  !! Set the hyperparameters for the message passing layer
  !!
  !! Stores the layer dimensions and expands the vertex/edge feature counts to
  !! one entry per step 0..num_time_steps. It then selects the message and
  !! readout activations and the kernel initialiser; the module defaults are
  !! used when the corresponding argument is unallocated. Finally it records
  !! the per-step message parameter counts and the readout parameter count.
  !! Any existing input/output shapes are cleared so that init sets them again.
  !!
  !! Execution stops with an error (non-zero exit status) if a feature count
  !! vector has neither 1 nor num_time_steps + 1 entries.
  use, intrinsic :: iso_fortran_env, only: error_unit
  use athena__activation, only: activation_setup
  use athena__initialiser, only: get_default_initialiser, initialiser_setup
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Instance of the message passing layer
  integer, dimension(:), intent(in) :: num_vertex_features
  !! Number of vertex features (a single value or one per step 0..T)
  integer, dimension(:), intent(in) :: num_edge_features
  !! Number of edge features (a single value or one per step 0..T)
  integer, intent(in) :: min_vertex_degree
  !! Minimum vertex degree
  integer, intent(in) :: max_vertex_degree
  !! Maximum vertex degree
  integer, intent(in) :: num_time_steps
  !! Number of time steps
  integer, intent(in) :: num_outputs
  !! Number of outputs
  class(base_actv_type), allocatable, intent(in) :: &
       message_activation, &
       readout_activation
  !! Message and readout activation functions (unallocated = use default)
  class(base_init_type), allocatable, intent(in) :: kernel_initialiser
  !! Kernel initialiser (unallocated = default for the message activation)
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  ! Local variables
  integer :: t
  !! Loop index
  character(len=256) :: buffer
  !! Name of the default kernel initialiser


  this%name = 'duvenaud'
  this%type = 'msgp'
  this%input_rank = 2
  this%output_rank = 1
  this%min_vertex_degree = min_vertex_degree
  this%max_vertex_degree = max_vertex_degree
  this%num_time_steps = num_time_steps
  this%num_outputs = num_outputs

  call expand_per_step( &
       num_vertex_features, "num_vertex_features", this%num_vertex_features &
  )
  call expand_per_step( &
       num_edge_features, "num_edge_features", this%num_edge_features &
  )

  this%use_graph_input = .true.
  this%use_graph_output = .false.


  !---------------------------------------------------------------------------
  ! Activation functions and kernel initialiser
  !---------------------------------------------------------------------------
  if(allocated(this%activation)) deallocate(this%activation)
  if(allocated(this%activation_readout)) deallocate(this%activation_readout)
  if(.not.allocated(message_activation))then
     this%activation = activation_setup(default_message_actv_name)
  else
     allocate( this%activation, source=message_activation )
  end if
  if(.not.allocated(readout_activation))then
     this%activation_readout = activation_setup(default_readout_actv_name)
  else
     allocate(this%activation_readout, source=readout_activation)
  end if
  if(allocated(this%kernel_init)) deallocate(this%kernel_init)
  if(.not.allocated(kernel_initialiser))then
     buffer = get_default_initialiser(this%activation%name)
     this%kernel_init = initialiser_setup(buffer)
  else
     allocate(this%kernel_init, source=kernel_initialiser)
  end if
  if(present(verbose))then
     if(abs(verbose).gt.0)then
        write(*,'("DUVENAUD message activation function: ",A)') &
             trim(this%activation%name)
        write(*,'("DUVENAUD readout activation function: ",A)') &
             trim(this%activation_readout%name)
        write(*,'("DUVENAUD kernel initialiser: ",A)') &
             trim(this%kernel_init%name)
     end if
  end if


  !---------------------------------------------------------------------------
  ! Parameter counts
  !---------------------------------------------------------------------------
  ! Message kernel of step t maps nv(t-1) + ne(0) inputs to nv(t) outputs for
  ! each of the (max - min + 1) vertex degrees.
  if(allocated(this%num_params_msg)) deallocate(this%num_params_msg)
  allocate(this%num_params_msg(1:this%num_time_steps))
  do t = 1, this%num_time_steps
     this%num_params_msg(t) = &
          ( this%num_vertex_features(t-1) + this%num_edge_features(0) ) * &
          this%num_vertex_features(t) * &
          ( this%max_vertex_degree - this%min_vertex_degree + 1 )
  end do
  ! Readout kernels exist only for steps t = 1..T, each of shape
  ! [ num_outputs, nv(t) ] (see init_duvenaud); step 0 has no readout kernel.
  this%num_params_readout = &
       this%num_outputs * sum( this%num_vertex_features(1:this%num_time_steps) )

  if(allocated(this%input_shape)) deallocate(this%input_shape)
  if(allocated(this%output_shape)) deallocate(this%output_shape)

contains

  subroutine expand_per_step(values, label, expanded)
    !! Expand a feature count specification to one entry per step 0..T
    !!
    !! A single value is broadcast to every step. Otherwise the vector must
    !! already hold num_time_steps + 1 entries, or execution stops.
    integer, dimension(:), intent(in) :: values
    !! User-supplied feature count(s)
    character(*), intent(in) :: label
    !! Argument name used in the error message
    integer, dimension(:), allocatable, intent(out) :: expanded
    !! Feature counts indexed 0..num_time_steps

    if(size(values, 1) .eq. 1)then
       allocate(expanded(0:num_time_steps), source = values(1))
    elseif(size(values, 1) .eq. num_time_steps + 1)then
       allocate(expanded(0:num_time_steps), source = values)
    else
       write(error_unit,*) "Error: " // label // " must be a scalar or a &
            &vector of length num_time_steps + 1"
       error stop
    end if

  end subroutine expand_per_step

end subroutine set_hyperparams_duvenaud
424 !###############################################################################
425
426
427 !###############################################################################
subroutine init_duvenaud(this, input_shape, verbose)
  !! Initialise the message passing layer
  !!
  !! Sets the input/output shapes and the parameter count, then allocates and
  !! initialises 2*T learnable arrays (T = num_time_steps):
  !!  - params(t),   t = 1..T : message kernel of shape
  !!    [ nv(t), nv(t-1) + ne(0), max_vertex_degree - min_vertex_degree + 1 ]
  !!  - params(T+t), t = 1..T : readout kernel of shape [ num_outputs, nv(t) ]
  !! Parameter storage from a previous initialisation is released first, so a
  !! second call does not fail on an already-allocated array.
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Instance of the message passing layer
  integer, dimension(:), intent(in) :: input_shape
  !! Input shape
  integer, optional, intent(in) :: verbose
  !! Verbosity level (accepted for interface compatibility; unused here)

  ! Local variables
  integer :: t
  !! Time step index
  integer :: t_out
  !! Index of the readout kernel belonging to time step t


  !---------------------------------------------------------------------------
  ! Initialise shapes and parameter count
  !---------------------------------------------------------------------------
  if(.not.allocated(this%input_shape)) call this%set_shape([input_shape])
  this%output_shape = [this%num_outputs]
  this%num_params = this%get_num_params()


  !---------------------------------------------------------------------------
  ! Allocate weights (kernels)
  !---------------------------------------------------------------------------
  if(allocated(this%weight_shape)) deallocate(this%weight_shape)
  if(allocated(this%params)) deallocate(this%params)
  allocate(this%weight_shape(3,2*this%num_time_steps))
  allocate(this%params(2*this%num_time_steps))
  do t = 1, this%num_time_steps
     t_out = t + this%num_time_steps

     ! Message kernel of step t
     this%weight_shape(:,t) = [ &
          this%num_vertex_features(t), &
          this%num_vertex_features(t-1) + this%num_edge_features(0), &
          this%max_vertex_degree - this%min_vertex_degree + 1 &
     ]
     call this%params(t)%allocate( [ this%weight_shape(:,t), 1 ] )
     call this%params(t)%set_requires_grad(.true.)
     this%params(t)%fix_pointer = .true.
     this%params(t)%is_temporary = .false.
     this%params(t)%is_sample_dependent = .false.
     this%params(t)%indices = [ this%min_vertex_degree, this%max_vertex_degree ]

     ! Readout kernel of step t (third weight_shape entry fixed to 1)
     this%weight_shape(:,t_out) = &
          [ this%num_outputs, this%num_vertex_features(t), 1 ]
     call this%params(t_out)%allocate( [ this%weight_shape(:2,t_out), 1 ] )
     call this%params(t_out)%set_requires_grad(.true.)
     this%params(t_out)%fix_pointer = .true.
     this%params(t_out)%is_temporary = .false.
     this%params(t_out)%is_sample_dependent = .false.
  end do


  !---------------------------------------------------------------------------
  ! Initialise weights (kernels)
  !---------------------------------------------------------------------------
  do t = 1, this%num_time_steps, 1
     call this%kernel_init%initialise( &
          this%params(t)%val(:,1), &
          fan_in = this%num_vertex_features(t-1) + this%num_edge_features(0), &
          fan_out = this%num_vertex_features(t), &
          spacing = [ this%num_vertex_features(t-1) ] &
     )
     ! NOTE(review): fan_in/spacing here span all steps 0..T although each
     ! readout kernel only has shape [ num_outputs, nv(t) ] - confirm intended.
     call this%kernel_init%initialise( &
          this%params(t+this%num_time_steps)%val(:,1), &
          fan_in = sum(this%num_vertex_features), &
          fan_out = this%num_outputs, &
          spacing = this%num_vertex_features &
     )
  end do


  !---------------------------------------------------------------------------
  ! Allocate arrays
  !---------------------------------------------------------------------------
  if(allocated(this%output)) deallocate(this%output)
  allocate(this%output(1,1))
  if(allocated(this%z)) deallocate(this%z)

end subroutine init_duvenaud
518 !###############################################################################
519
520
521 !##############################################################################!
subroutine set_graph_duvenaud(this, graph)
  !! Set the graph structure of the input data
  !!
  !! Copies adj_ia, adj_ja, edge_weights, num_edges and num_vertices of each
  !! graph in the batch into the layer. The stored graph array is reallocated
  !! only when the batch size changes. If any vertex index stored in
  !! adj_ja(1,:) exceeds the graph's number of vertices, the error is reported
  !! on the standard error unit and execution stops with a non-zero status.
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Instance of the layer
  type(graph_type), dimension(:), intent(in) :: graph
  !! Graph structure of input data

  ! Local variables
  integer :: s
  !! Sample (graph) index

  if(allocated(this%graph))then
     if(size(this%graph).ne.size(graph)) deallocate(this%graph)
  end if
  if(.not.allocated(this%graph)) allocate(this%graph(size(graph)))

  do s = 1, size(graph)
     this%graph(s)%adj_ia = graph(s)%adj_ia
     this%graph(s)%adj_ja = graph(s)%adj_ja
     this%graph(s)%edge_weights = graph(s)%edge_weights
     this%graph(s)%num_edges = graph(s)%num_edges
     this%graph(s)%num_vertices = graph(s)%num_vertices
     if(any(this%graph(s)%adj_ja(1,:).gt.this%graph(s)%num_vertices))then
        write(error_unit,*) "Error: graph adjacency matrix has indices greater than &
             &the number of vertices", s, &
             this%graph(s)%num_vertices
        write(error_unit,*) "Adjacency matrix indices: ", this%graph(s)%adj_ja
        error stop
     end if
  end do

end subroutine set_graph_duvenaud
560 !##############################################################################!
561
562
563 !##############################################################################!
564 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
565 !##############################################################################!
566
567
568 !###############################################################################
subroutine print_to_unit_duvenaud(this, unit)
  !! Print the Duvenaud message passing layer to a unit
  !!
  !! Writes the number of time steps and the per-step vertex and edge feature
  !! counts. It then writes the message and readout activations, each unless
  !! named 'none'. Finally every kernel value is written between "WEIGHTS" and
  !! "END WEIGHTS": the message kernels for t = 1..T, then the readout kernels
  !! for t = 1..T.
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(in) :: this
  !! Instance of the message passing layer
  integer, intent(in) :: unit
  !! Unit number to write to

  ! Local variables
  integer :: idx
  !! Index over all 2*T parameter arrays
  integer :: num_entries
  !! Number of per-step feature counts (T + 1)
  character(100) :: row_fmt
  !! Run-time format for the feature count rows


  num_entries = this%num_time_steps + 1

  ! Hyperparameters
  write(unit,'(3X,"NUM_TIME_STEPS = ",I0)') this%num_time_steps
  write(row_fmt,'("(3X,""NUM_VERTEX_FEATURES ="",",I0,"(1X,I0))")') num_entries
  write(unit,row_fmt) this%num_vertex_features
  write(row_fmt,'("(3X,""NUM_EDGE_FEATURES ="",",I0,"(1X,I0))")') num_entries
  write(unit,row_fmt) this%num_edge_features

  if(this%activation%name .ne. 'none') &
       call this%activation%print_to_unit(unit, identifier='MESSAGE')
  if(this%activation_readout%name .ne. 'none') &
       call this%activation_readout%print_to_unit(unit, identifier='READOUT')

  ! Learned parameters: params(1:T) are message kernels, params(T+1:2T) readout
  write(unit,'("WEIGHTS")')
  do idx = 1, 2 * this%num_time_steps
     write(unit,'(5(E16.8E2))') this%params(idx)%val(:,1)
  end do
  write(unit,'("END WEIGHTS")')

end subroutine print_to_unit_duvenaud
617 !###############################################################################
618
619
620 !###############################################################################
subroutine read_duvenaud(this, unit, verbose)
  !! Read the message passing layer
  !!
  !! NOTE: reading is not implemented for this layer. Nothing is read from
  !! the unit and the layer is left unchanged. A warning is written to the
  !! standard error unit so that the unchanged (placeholder) layer is not
  !! silently mistaken for a successfully loaded one.
  use, intrinsic :: iso_fortran_env, only: error_unit
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout) :: this
  !! Instance of the message passing layer
  integer, intent(in) :: unit
  !! Unit to read from
  integer, optional, intent(in) :: verbose
  !! Verbosity level

  write(error_unit,'(A)') "WARNING: reading a DUVENAUD message passing &
       &layer is not implemented; layer left unchanged"

end subroutine read_duvenaud
633 !###############################################################################
634
635
636 !###############################################################################
function read_duvenaud_msgpass_layer(unit, verbose) result(layer)
  !! Read duvenaud message passing layer from file and return layer
  !!
  !! A placeholder layer is built first (one time step, zero vertex/edge
  !! features, max vertex degree 1, one output). It is then passed to the
  !! layer's read procedure.
  implicit none

  ! Arguments
  integer, intent(in) :: unit
  !! Unit number
  integer, optional, intent(in) :: verbose
  !! Verbosity level
  class(base_layer_type), allocatable :: layer
  !! Instance of the message passing layer

  ! Local variables
  integer :: verbose_
  !! Verbosity level (0 unless supplied)

  ! Assigned here rather than in the declaration: an initialised local is
  ! implicitly SAVEd and would keep a previous call's verbosity.
  verbose_ = 0
  if(present(verbose)) verbose_ = verbose

  allocate(layer, source = duvenaud_msgpass_layer_type( &
       num_vertex_features = [ 0 ], &
       num_edge_features = [ 0 ], &
       num_time_steps = 1, &
       max_vertex_degree = 1, &
       num_outputs = 1 &
  ))
  call layer%read(unit, verbose=verbose_)

end function read_duvenaud_msgpass_layer
664 !###############################################################################
665
666
667 !##############################################################################!
668 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
669 !##############################################################################!
670
671
672 !##############################################################################!
subroutine update_message_duvenaud(this, input)
  !! Update the message
  !!
  !! For every sample s and time step t = 1..T:
  !!  - the current vertex features and the edge features are combined over
  !!    the graph with duvenaud_propagate;
  !!  - the result goes through duvenaud_update with the degree-specific
  !!    kernel params(t);
  !!  - the message activation is applied unless it is unset or named 'none';
  !!  - the result is stored in z(t,s), which becomes the vertex input of
  !!    step t+1.
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout), target :: this
  !! Instance of the message passing layer
  class(array_type), dimension(:,:), intent(in), target :: input
  !! Input to the message passing layer; input(1,s) holds the vertex
  !! features and input(2,s) the edge features of sample s

  ! Local variables
  integer :: s, t
  !! Batch index, time step
  integer :: num_samples
  !! Batch size
  logical :: has_activation
  !! Whether the message activation must be applied
  type(array_type), pointer :: ptr1, ptr2, ptr3, ptr_edge, ptr_params
  !! Pointers to arrays


  num_samples = size(input,2)

  ! (Re)allocate the per-step feature store when its shape does not match.
  if(allocated(this%z))then
     if( size(this%z,1).ne.this%num_time_steps .or. &
          size(this%z,2).ne.num_samples ) deallocate(this%z)
  end if
  if(.not.allocated(this%z)) &
       allocate(this%z(this%num_time_steps,num_samples))

  ! An activation named 'none' is skipped rather than applied.
  has_activation = .false.
  if(allocated(this%activation)) &
       has_activation = trim(this%activation%name) .ne. "none"

  do s = 1, num_samples
     ! Step 1 starts from the input vertex features; each later step starts
     ! from the previous step's output z(t-1,s).
     ptr1 => input(1,s)
     ptr_edge => input(2,s)
     do t = 1, this%num_time_steps
        ptr2 => duvenaud_propagate( &
             ptr1, ptr_edge, &
             this%graph(s)%adj_ia, this%graph(s)%adj_ja &
        )

        ptr_params => this%params(t)
        ptr3 => duvenaud_update( &
             ptr2, ptr_params, &
             this%graph(s)%adj_ia, &
             this%min_vertex_degree, this%max_vertex_degree &
        )
        if(has_activation) ptr3 => this%activation%apply( ptr3 )
        call this%z(t,s)%zero_grad()
        call this%z(t,s)%assign_and_deallocate_source(ptr3)
        this%z(t,s)%is_temporary = .false.
        ptr1 => this%z(t,s)
     end do
  end do

end subroutine update_message_duvenaud
736 !###############################################################################
737
738
739 !##############################################################################!
subroutine update_readout_duvenaud(this)
  !! Update the readout
  !!
  !! For every time step t = 1..T (outer) and sample s (inner), this computes
  !! activation_readout( matmul(params(T+t), z(t,s)) ) and reduces it with
  !! sum over dim 2, with new_dim_index = s and new_dim_size = batch size.
  !! The terms are added together in that order, and the total is stored in
  !! output(1,1) after its gradient has been zeroed.
  implicit none

  ! Arguments
  class(duvenaud_msgpass_layer_type), intent(inout), target :: this
  !! Instance of the message passing layer

  ! Local variables
  integer :: step, sample, num_samples
  !! Time step, sample index and batch size
  logical :: first_term
  !! True until the first readout term has been formed
  type(array_type), pointer :: kernel, features, projected, activated, total
  !! Readout kernel, step features, intermediate results and running total


  num_samples = size(this%z,2)
  first_term = .true.
  call this%output(1,1)%zero_grad()

  do step = 1, this%num_time_steps
     kernel => this%params(step + this%num_time_steps)
     do sample = 1, num_samples
        features => this%z(step, sample)
        projected => matmul( kernel, features )
        activated => this%activation_readout%apply( projected )
        if(first_term)then
           total => sum( activated, dim = 2, &
                new_dim_index=sample, new_dim_size=num_samples )
           first_term = .false.
        else
           total => total + sum( activated, dim = 2, &
                new_dim_index=sample, new_dim_size=num_samples )
        end if
     end do
  end do

  call this%output(1,1)%assign_and_deallocate_source(total)
  this%output(1,1)%is_temporary = .false.

end subroutine update_readout_duvenaud
778 !###############################################################################
779
780 end module athena__duvenaud_msgpass_layer
781