| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule(athena__msgpass_layer) athena__msgpass_layer_submodule | ||
| 2 | !! Submodule containing implementations for a message passing layer | ||
| 3 | implicit none | ||
| 4 | |||
| 5 | |||
| 6 | |||
| 7 | contains | ||
| 8 | |||
| 9 | !############################################################################### | ||
| 10 | ✗ | pure module function get_num_params_msgpass(this) result(num_params) | |
| 11 | !! Get the number of learnable parameters in the layer | ||
| 12 | implicit none | ||
| 13 | |||
| 14 | ! Arguments | ||
| 15 | class(msgpass_layer_type), intent(in) :: this | ||
| 16 | !! Instance of the layer type | ||
| 17 | integer :: num_params | ||
| 18 | !! Number of learnable parameters | ||
| 19 | |||
| 20 | ! Local variables | ||
| 21 | integer :: t | ||
| 22 | !! Time step | ||
| 23 | |||
| 24 | ✗ | num_params = sum(this%num_params_msg) + this%num_params_readout | |
| 25 | ✗ | end function get_num_params_msgpass | |
| 26 | !############################################################################### | ||
| 27 | |||
| 28 | |||
| 29 | !##############################################################################! | ||
| 30 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 31 | !##############################################################################! | ||
| 32 | |||
| 33 | |||
| 34 | !############################################################################### | ||
| 35 | ✗ | module function layer_setup( & | |
| 36 | num_features, num_time_steps, & | ||
| 37 | verbose & | ||
| 38 | ✗ | ) result(layer) | |
| 39 | !! Procedure to set up the layer | ||
| 40 | implicit none | ||
| 41 | |||
| 42 | ! Arguments | ||
| 43 | integer, dimension(2), intent(in) :: num_features | ||
| 44 | !! Number of features | ||
| 45 | integer, intent(in) :: num_time_steps | ||
| 46 | !! Number of time steps | ||
| 47 | integer, optional, intent(in) :: verbose | ||
| 48 | !! Verbosity level | ||
| 49 | |||
| 50 | class(msgpass_layer_type), allocatable :: layer | ||
| 51 | !! Instance of the layer type | ||
| 52 | |||
| 53 | ! Local variables | ||
| 54 | integer :: verbose_ = 0 | ||
| 55 | !! Verbosity level | ||
| 56 | |||
| 57 | |||
| 58 | ✗ | end function layer_setup | |
| 59 | !############################################################################### | ||
| 60 | |||
| 61 | |||
| 62 | !############################################################################### | ||
| 63 | ! module subroutine set_hyperparams_msgpass( & | ||
| 64 | ! this, & | ||
| 65 | ! num_features, num_time_steps, num_outputs, verbose & | ||
| 66 | ! ) | ||
| 67 | ! !! Set the hyperparameters for the layer | ||
| 68 | ! implicit none | ||
| 69 | |||
| 70 | ! ! Arguments | ||
| 71 | ! class(msgpass_layer_type), intent(inout) :: this | ||
| 72 | ! !! Instance of the layer type | ||
| 73 | ! integer, dimension(2), intent(in) :: num_features | ||
| 74 | ! !! Number of features | ||
| 75 | ! integer, intent(in) :: num_time_steps | ||
| 76 | ! !! Number of time steps | ||
| 77 | ! integer, intent(in) :: num_outputs | ||
| 78 | ! !! Number of output features | ||
| 79 | ! integer, optional, intent(in) :: verbose | ||
| 80 | ! !! Verbosity level | ||
| 81 | |||
| 82 | |||
| 83 | ! this%name = 'msgpass' | ||
| 84 | ! this%type = 'msgp' | ||
| 85 | ! this%input_rank = 1 | ||
| 86 | ! this%num_outputs = num_outputs | ||
| 87 | ! this%num_time_steps = num_time_steps | ||
| 88 | ! this%num_vertex_features = num_features(1) | ||
| 89 | ! this%num_edge_features = num_features(2) | ||
| 90 | |||
| 91 | ! end subroutine set_hyperparams_msgpass | ||
| 92 | !############################################################################### | ||
| 93 | |||
| 94 | |||
| 95 | !############################################################################### | ||
| 96 | ✗ | module subroutine init_msgpass(this, input_shape, verbose) | |
| 97 | !! Initialise the layer | ||
| 98 | implicit none | ||
| 99 | |||
| 100 | ! Arguments | ||
| 101 | class(msgpass_layer_type), intent(inout) :: this | ||
| 102 | !! Instance of the layer type | ||
| 103 | integer, dimension(:), intent(in) :: input_shape | ||
| 104 | !! Input shape | ||
| 105 | integer, optional, intent(in) :: verbose | ||
| 106 | !! Verbosity level | ||
| 107 | |||
| 108 | ! Local variables | ||
| 109 | integer :: t | ||
| 110 | !! Time step | ||
| 111 | integer :: verbose_ = 0 | ||
| 112 | !! Verbosity level | ||
| 113 | integer :: num_params_message | ||
| 114 | !! Number of parameters in the message | ||
| 115 | integer :: num_params_readout | ||
| 116 | !! Number of parameters in the readout | ||
| 117 | |||
| 118 | |||
| 119 | !--------------------------------------------------------------------------- | ||
| 120 | ! Initialise optional arguments | ||
| 121 | !--------------------------------------------------------------------------- | ||
| 122 | ✗ | if(present(verbose)) verbose_ = verbose | |
| 123 | |||
| 124 | |||
| 125 | !--------------------------------------------------------------------------- | ||
| 126 | ! Initialise number of inputs | ||
| 127 | !--------------------------------------------------------------------------- | ||
| 128 | ✗ | this%input_shape = [ 1 ] !input_shape | |
| 129 | ! this%output_shape = this%num_outputs | ||
| 130 | !if(.not.allocated(this%input_shape)) call this%set_shape(input_shape) | ||
| 131 | ✗ | this%num_params = this%get_num_params() | |
| 132 | |||
| 133 | |||
| 134 | !--------------------------------------------------------------------------- | ||
| 135 | ! Allocate arrays | ||
| 136 | !--------------------------------------------------------------------------- | ||
| 137 | ✗ | if(allocated(this%output)) deallocate(this%output) | |
| 138 | |||
| 139 | ✗ | end subroutine init_msgpass | |
| 140 | !############################################################################### | ||
| 141 | |||
| 142 | |||
| 143 | !############################################################################### | ||
| 144 |
1/2✓ Branch 0 taken 20 times.
✗ Branch 1 not taken.
|
20 | module subroutine set_graph_msgpass(this, graph) |
| 145 | !! Set the graph structure of the input data | ||
| 146 | implicit none | ||
| 147 | |||
| 148 | ! Arguments | ||
| 149 | class(msgpass_layer_type), intent(inout) :: this | ||
| 150 | !! Instance of the layer | ||
| 151 | type(graph_type), dimension(:), intent(in) :: graph | ||
| 152 | !! Graph structure of input data | ||
| 153 | |||
| 154 | ! Local variables | ||
| 155 | integer :: s | ||
| 156 | !! Loop indices | ||
| 157 | |||
| 158 |
2/2✓ Branch 0 taken 5 times.
✓ Branch 1 taken 15 times.
|
20 | if(allocated(this%graph))then |
| 159 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
|
5 | if(size(this%graph).ne.size(graph))then |
| 160 | ✗ | deallocate(this%graph) | |
| 161 | ✗ | allocate(this%graph(size(graph))) | |
| 162 | end if | ||
| 163 | else | ||
| 164 |
27/60✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 15 times.
✓ Branch 9 taken 15 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 15 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 15 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 15 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 15 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 15 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 15 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 15 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 15 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 15 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 15 times.
✓ Branch 38 taken 15 times.
✓ Branch 39 taken 15 times.
✓ Branch 40 taken 15 times.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✓ Branch 43 taken 15 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 15 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 15 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 15 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 15 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 15 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 15 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 15 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 15 times.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✗ Branch 64 not taken.
✓ Branch 65 taken 15 times.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✗ Branch 68 not taken.
✗ Branch 69 not taken.
|
30 | allocate(this%graph(size(graph))) |
| 165 | end if | ||
| 166 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✓ Branch 9 taken 20 times.
✓ Branch 10 taken 20 times.
|
40 | do s = 1, size(graph) |
| 167 |
16/28✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 20 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 20 times.
✓ Branch 24 taken 5 times.
✓ Branch 25 taken 15 times.
✓ Branch 26 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 15 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 15 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 15 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 116 times.
✓ Branch 35 taken 20 times.
|
136 | this%graph(s)%adj_ia = graph(s)%adj_ia |
| 168 |
25/44✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 20 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 20 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 20 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 20 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 20 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 20 times.
✓ Branch 36 taken 5 times.
✓ Branch 37 taken 15 times.
✓ Branch 38 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 15 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 15 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 15 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 15 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 15 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 242 times.
✓ Branch 53 taken 20 times.
✓ Branch 54 taken 484 times.
✓ Branch 55 taken 242 times.
|
746 | this%graph(s)%adj_ja = graph(s)%adj_ja |
| 169 |
16/28✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 20 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 20 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 20 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 20 times.
✓ Branch 24 taken 5 times.
✓ Branch 25 taken 15 times.
✓ Branch 26 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 15 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 15 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 15 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 121 times.
✓ Branch 35 taken 20 times.
|
141 | this%graph(s)%edge_weights = graph(s)%edge_weights |
| 170 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
|
20 | this%graph(s)%num_edges = graph(s)%num_edges |
| 171 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
|
40 | this%graph(s)%num_vertices = graph(s)%num_vertices |
| 172 | end do | ||
| 173 | |||
| 174 | 20 | end subroutine set_graph_msgpass | |
| 175 | !############################################################################### | ||
| 176 | |||
| 177 | |||
| 178 | !##############################################################################! | ||
| 179 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 180 | !##############################################################################! | ||
| 181 | |||
| 182 | |||
| 183 | !############################################################################### | ||
| 184 |
1/2✓ Branch 0 taken 30 times.
✗ Branch 1 not taken.
|
30 | module subroutine forward_msgpass(this, input) |
| 185 | !! Forward propagation for the layer | ||
| 186 | implicit none | ||
| 187 | |||
| 188 | ! Arguments | ||
| 189 | class(msgpass_layer_type), intent(inout) :: this | ||
| 190 | !! Instance of the layer type | ||
| 191 | class(array_type), dimension(:,:), intent(in) :: input | ||
| 192 | !! Input data (i.e. vertex and edge features) | ||
| 193 | |||
| 194 | |||
| 195 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 30 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 30 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 30 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 30 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 30 times.
|
30 | call this%update_message(input) |
| 196 | 30 | call this%update_readout() | |
| 197 | |||
| 198 | 30 | end subroutine forward_msgpass | |
| 199 | !############################################################################### | ||
| 200 | |||
| 201 | end submodule athena__msgpass_layer_submodule | ||
| 202 |