GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_duvenaud.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_msgpass_duvenaud
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
module function duvenaud_propagate( &
     vertex_features, edge_features, adj_ia, adj_ja &
) result(c)
  !! Sum, for every vertex, the features of its adjacency entries.
  !!
  !! Column `v` of the result is the sum over `w = adj_ia(v) .. adj_ia(v+1)-1`
  !! of the stacked vector
  !!   [ vertex_features%val(:, adj_ja(1,w)) ; edge_features%val(:, adj_ja(2,w)) ]
  !! so the result has `size(vertex_features%val,1) + size(edge_features%val,1)`
  !! rows and `size(vertex_features%val,2)` columns.
  !!
  !! The adjacency arrays and the four partial-derivative procedure pointers
  !! are always stored on the result; operand links are only recorded when at
  !! least one input requires a gradient.
  implicit none
  class(array_type), intent(in), target :: vertex_features, edge_features
  integer, dimension(:), intent(in) :: adj_ia
  integer, dimension(:,:), intent(in) :: adj_ja
  type(array_type), pointer :: c

  integer :: ivert, ientry
  integer :: n_vert_feat, n_edge_feat, n_vert

  n_vert_feat = size(vertex_features%val, 1)
  n_edge_feat = size(edge_features%val, 1)
  n_vert = size(vertex_features%val, 2)

  c => vertex_features%create_result( &
       array_shape = [ n_vert_feat + n_edge_feat, n_vert ] &
  )

  ! Each iteration writes only its own column of c%val, so the outer loop
  ! has no cross-iteration dependency.
  do concurrent(ivert = 1:n_vert)
    c%val(:, ivert) = 0.0_real32
    do ientry = adj_ia(ivert), adj_ia(ivert+1) - 1
      ! upper rows: vertex features selected by adj_ja(1,:)
      c%val(1:n_vert_feat, ivert) = c%val(1:n_vert_feat, ivert) + &
           vertex_features%val(:, adj_ja(1, ientry))
      ! lower rows: edge features selected by adj_ja(2,:)
      c%val(n_vert_feat+1:, ivert) = c%val(n_vert_feat+1:, ivert) + &
           edge_features%val(:, adj_ja(2, ientry))
    end do
  end do

  ! Keep the adjacency on the result; the partial procedures read it back
  ! through this%indices and this%adj_ja.
  c%indices = adj_ia
  c%adj_ja = adj_ja
  c%get_partial_left => get_partial_duvenaud_propagate_left
  c%get_partial_right => get_partial_duvenaud_propagate_right
  c%get_partial_left_val => get_partial_duvenaud_propagate_left_val
  c%get_partial_right_val => get_partial_duvenaud_propagate_right_val

  ! Nothing else to record when neither operand takes part in differentiation.
  if(.not. (vertex_features%requires_grad .or. edge_features%requires_grad)) return

  c%requires_grad = .true.
  c%is_forward = vertex_features%is_forward .or. edge_features%is_forward
  c%operation = 'duvenaud_propagate'
  c%left_operand => vertex_features
  c%right_operand => edge_features
  c%owns_left_operand = vertex_features%is_temporary
  c%owns_right_operand = edge_features%is_temporary
end function duvenaud_propagate
52 !-------------------------------------------------------------------------------
function get_partial_duvenaud_propagate_left(this, upstream_grad) result(output)
  !! Left-operand partial of `duvenaud_propagate`, returned as an `array_type`.
  !!
  !! Calls `duvenaud_propagate` with `upstream_grad` as the vertex-feature
  !! argument and the stored right operand as the edge-feature argument,
  !! using the adjacency saved on `this` (`indices`, `adj_ja`).
  implicit none
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  type(array_type), pointer :: propagated
  logical :: saved_temporary_flag

  ! duvenaud_propagate copies the operand's is_temporary flag into the
  ! owns_right_operand field of its result. The flag is cleared for the
  ! duration of the call and restored right after.
  ! NOTE(review): presumably this stops the new node from taking ownership
  ! of the shared operand -- confirm against array_type's cleanup logic.
  saved_temporary_flag = this%right_operand%is_temporary
  this%right_operand%is_temporary = .false.
  propagated => duvenaud_propagate( &
       upstream_grad, this%right_operand, this%indices, this%adj_ja &
  )
  this%right_operand%is_temporary = saved_temporary_flag

  ! Hand the pointer result over to the function result.
  call output%assign_and_deallocate_source(propagated)

end function get_partial_duvenaud_propagate_left
70 !-------------------------------------------------------------------------------
function get_partial_duvenaud_propagate_right(this, upstream_grad) result(output)
  !! Right-operand partial of `duvenaud_propagate`, returned as an `array_type`.
  !!
  !! Calls `duvenaud_propagate` with the stored left operand as the
  !! vertex-feature argument and `upstream_grad` as the edge-feature argument,
  !! using the adjacency saved on `this` (`indices`, `adj_ja`).
  implicit none
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  type(array_type), pointer :: propagated
  logical :: saved_temporary_flag

  ! duvenaud_propagate copies the operand's is_temporary flag into the
  ! owns_left_operand field of its result. The flag is cleared for the
  ! duration of the call and restored right after.
  ! NOTE(review): presumably this stops the new node from taking ownership
  ! of the shared operand -- confirm against array_type's cleanup logic.
  saved_temporary_flag = this%left_operand%is_temporary
  this%left_operand%is_temporary = .false.
  propagated => duvenaud_propagate( &
       this%left_operand, upstream_grad, this%indices, this%adj_ja &
  )
  this%left_operand%is_temporary = saved_temporary_flag

  ! Hand the pointer result over to the function result.
  call output%assign_and_deallocate_source(propagated)

end function get_partial_duvenaud_propagate_right
88 !-------------------------------------------------------------------------------
pure subroutine get_partial_duvenaud_propagate_left_val( &
     this, upstream_grad, output &
)
  !! Reverse-mode partial of `duvenaud_propagate` with respect to the
  !! vertex features (left operand), on raw value arrays.
  !!
  !! For every vertex `v` and every adjacency entry
  !! `w = this%indices(v) .. this%indices(v+1)-1`, the first
  !! `size(this%left_operand%val,1)` rows of `upstream_grad(:,v)` are added
  !! to column `this%adj_ja(1,w)` of `output`.
  !!
  !! Arguments:
  !!   this          - result node holding `indices`, `adj_ja` and the operands
  !!   upstream_grad - gradient with respect to the propagate output
  !!   output        - gradient with respect to the vertex features;
  !!                   zeroed on entry and then accumulated into
  implicit none
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  integer :: v, w, src, num_features, num_elements

  num_features = size(this%left_operand%val,1)
  num_elements = size(this%left_operand%val,2)
  output = 0._real32

  ! This is a scatter-add: the same source column can be reached from
  ! several vertices v, so iterations are NOT independent. An ordinary DO
  ! is required here; DO CONCURRENT would violate its independence rule and
  ! race when the compiler actually runs iterations in parallel.
  do v = 1, num_elements
    do w = this%indices(v), this%indices(v+1)-1
      src = this%adj_ja(1,w)
      output(:,src) = output(:,src) + upstream_grad(1:num_features, v)
    end do
  end do
end subroutine get_partial_duvenaud_propagate_left_val
109 !-------------------------------------------------------------------------------
pure subroutine get_partial_duvenaud_propagate_right_val( &
     this, upstream_grad, output &
)
  !! Reverse-mode partial of `duvenaud_propagate` with respect to the
  !! edge features (right operand), on raw value arrays.
  !!
  !! For every vertex `v` and every adjacency entry
  !! `w = this%indices(v) .. this%indices(v+1)-1`, the rows of
  !! `upstream_grad(:,v)` below the first `size(this%left_operand%val,1)`
  !! are added to column `this%adj_ja(2,w)` of `output`.
  !!
  !! Arguments:
  !!   this          - result node holding `indices`, `adj_ja` and the operands
  !!   upstream_grad - gradient with respect to the propagate output
  !!   output        - gradient with respect to the edge features;
  !!                   zeroed on entry and then accumulated into
  implicit none
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  integer :: v, w, src, num_features, num_elements

  ! Row offset and vertex count both come from the LEFT (vertex) operand:
  ! the edge part of the stacked vector starts after the vertex features.
  num_features = size(this%left_operand%val,1)
  num_elements = size(this%left_operand%val,2)
  output = 0._real32

  ! This is a scatter-add: the same edge column can be referenced by more
  ! than one adjacency entry, so iterations are NOT independent. An
  ! ordinary DO is required here; DO CONCURRENT would violate its
  ! independence rule and race under real parallel execution.
  do v = 1, num_elements
    do w = this%indices(v), this%indices(v+1)-1
      src = this%adj_ja(2,w)
      output(:,src) = output(:,src) + upstream_grad(num_features+1:, v)
    end do
  end do
end subroutine get_partial_duvenaud_propagate_right_val
130 !###############################################################################
131
132
133 !###############################################################################
module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c)
  !! Apply a degree-dependent weight matrix to every column of `a`.
  !!
  !! `weight%val(:,1)` is read as consecutive blocks of
  !! `weight%shape(1) * weight%shape(2)` values, one block per degree bucket.
  !! For vertex `v` the bucket is
  !!   d = clamp(adj_ia(v+1) - adj_ia(v), min_degree, max_degree) - min_degree + 1
  !! and the result column is `W_d * (a%val(:,v) / d)`, where `W_d` is block
  !! `d` viewed as a `weight%shape(1)` x `weight%shape(2)` matrix.
  !!
  !! `adj_ia` and the four partial-derivative procedure pointers are always
  !! stored on the result; operand links are only recorded when at least one
  !! input requires a gradient. The weight is stored as the LEFT operand and
  !! `a` as the RIGHT operand.
  implicit none
  class(array_type), intent(in), target :: a
  class(array_type), intent(in), target :: weight
  integer, dimension(:), intent(in) :: adj_ia
  integer, intent(in) :: min_degree, max_degree
  type(array_type), pointer :: c

  integer :: ivert, degree, bucket
  integer :: n_out, n_in, block_size, first
  real(real32), pointer :: w_block(:,:)

  n_out = weight%shape(1)
  n_in = weight%shape(2)
  block_size = n_out * n_in

  c => a%create_result(array_shape=[n_out, size(a%val,2)])

  do ivert = 1, size(a%val,2)
    degree = adj_ia(ivert+1) - adj_ia(ivert)
    ! 1-based bucket index after clamping the degree to the supported range
    bucket = max( min_degree, min( degree, max_degree ) ) - min_degree + 1
    first = block_size * (bucket - 1) + 1
    ! rank-2 view of this bucket's slab of the flat weight vector (no copy)
    w_block(1:n_out, 1:n_in) => weight%val(first:first+block_size-1, 1)
    ! NOTE(review): the input is divided by the bucket index, which equals
    ! the clamped degree only when min_degree is 1 -- confirm intended.
    c%val(:,ivert) = matmul(w_block, a%val(:,ivert) / real(bucket, real32))
  end do
  c%indices = adj_ia

  c%get_partial_left => get_partial_duvenaud_update_weight
  c%get_partial_right => get_partial_duvenaud_update
  c%get_partial_left_val => get_partial_duvenaud_update_weight_val
  c%get_partial_right_val => get_partial_duvenaud_update_val

  if(a%requires_grad .or. weight%requires_grad) then
    c%requires_grad = .true.
    c%is_forward = a%is_forward .or. weight%is_forward
    c%operation = 'duvenaud_update'
    c%left_operand => weight
    c%right_operand => a
    c%owns_left_operand = weight%is_temporary
    c%owns_right_operand = a%is_temporary
  end if

end function duvenaud_update
175 !-------------------------------------------------------------------------------
function get_partial_duvenaud_update(this, upstream_grad) result(output)
  !! Right-operand (input features) partial of `duvenaud_update`, returned
  !! as an `array_type`.
  !!
  !! Calls `duvenaud_update` with `upstream_grad` in place of the input and
  !! the stored left operand as the weight. The degree bounds are read from
  !! the first two entries of the left operand's `indices`.
  implicit none
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  type(array_type), pointer :: updated
  logical :: saved_temporary_flag

  ! duvenaud_update copies the weight's is_temporary flag into the
  ! owns_left_operand field of its result. The flag is cleared for the
  ! duration of the call and restored right after.
  ! NOTE(review): presumably this stops the new node from taking ownership
  ! of the shared operand -- confirm against array_type's cleanup logic.
  saved_temporary_flag = this%left_operand%is_temporary
  this%left_operand%is_temporary = .false.
  updated => duvenaud_update( &
       upstream_grad, this%left_operand, this%indices, &
       this%left_operand%indices(1), this%left_operand%indices(2) &
  )
  this%left_operand%is_temporary = saved_temporary_flag

  ! Hand the pointer result over to the function result.
  call output%assign_and_deallocate_source(updated)

end function get_partial_duvenaud_update
191 !-------------------------------------------------------------------------------
function get_partial_duvenaud_update_weight(this, upstream_grad) result(output)
  !! Left-operand (weight) partial of `duvenaud_update`, returned as an
  !! `array_type`.
  !!
  !! Calls `duvenaud_update` with the stored right operand as the input and
  !! `upstream_grad` in place of the weight. The degree bounds are read from
  !! the first two entries of the left operand's `indices`.
  implicit none
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  type(array_type), pointer :: updated
  logical :: saved_temporary_flag

  ! duvenaud_update copies the input's is_temporary flag into the
  ! owns_right_operand field of its result. The flag is cleared for the
  ! duration of the call and restored right after.
  ! NOTE(review): presumably this stops the new node from taking ownership
  ! of the shared operand -- confirm against array_type's cleanup logic.
  saved_temporary_flag = this%right_operand%is_temporary
  this%right_operand%is_temporary = .false.
  updated => duvenaud_update( &
       this%right_operand, upstream_grad, this%indices, &
       this%left_operand%indices(1), this%left_operand%indices(2) &
  )
  this%right_operand%is_temporary = saved_temporary_flag

  ! Hand the pointer result over to the function result.
  call output%assign_and_deallocate_source(updated)

end function get_partial_duvenaud_update_weight
207 !-------------------------------------------------------------------------------
pure subroutine get_partial_duvenaud_update_val( &
     this, upstream_grad, output &
)
  !! Reverse-mode partial of `duvenaud_update` with respect to its input
  !! features (right operand), on raw value arrays.
  !!
  !! The forward pass computes, per vertex `v` in degree bucket `d`,
  !!   c(:,v) = W_d * (a(:,v) / d)
  !! so the gradient with respect to `a(:,v)` is `(W_d^T * g(:,v)) / d`.
  !! The `1/d` factor is applied here so the gradient matches the forward
  !! pass as implemented in `duvenaud_update`.
  !!
  !! Arguments:
  !!   this          - result node; `indices` gives the per-vertex entry
  !!                   ranges, `left_operand` holds the flat weights and,
  !!                   in `indices(1:2)`, the min/max degree
  !!   upstream_grad - gradient with respect to the update output
  !!   output        - gradient with respect to the input features
  implicit none
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  integer :: v, d
  integer :: interval, num_output_features, num_input_features
  integer :: min_degree, max_degree

  output = 0._real32
  num_output_features = size(upstream_grad,1)
  num_input_features = this%right_operand%shape(1)
  interval = num_output_features * num_input_features
  min_degree = this%left_operand%indices(1)
  max_degree = this%left_operand%indices(2)

  ! Each iteration writes only column v of output, so iterations are
  ! independent. The weight block is reshaped inline rather than through a
  ! shared work array, so no array scratch is shared between iterations.
  do concurrent(v=1:size(upstream_grad,2))
    ! 1-based degree bucket, same formula as the forward pass
    d = max( &
         min_degree, &
         min(this%indices(v+1) - this%indices(v), max_degree ) &
    ) - min_degree + 1
    ! vector * matrix == W_d^T * g; the division mirrors the forward a / d
    output(:,v) = matmul( &
         upstream_grad(:,v), &
         reshape( &
              this%left_operand%val((d-1)*interval+1:d*interval,1), &
              [num_output_features, num_input_features] &
         ) &
    ) / real(d, real32)
  end do

end subroutine get_partial_duvenaud_update_val
238 !-------------------------------------------------------------------------------
pure subroutine get_partial_duvenaud_update_weight_val( &
     this, upstream_grad, output &
)
  !! Reverse-mode partial of `duvenaud_update` with respect to the weights
  !! (left operand), on raw value arrays.
  !!
  !! The forward pass computes, per vertex `v` in degree bucket `d`,
  !!   c(:,v) = W_d * (a(:,v) / d)
  !! so vertex `v` contributes the outer product `g(:,v) * (a(:,v) / d)^T`
  !! to the gradient of block `W_d`. Blocks are laid out back to back in
  !! `output(:,1)`, column-major within each block. The `1/d` factor is
  !! applied here so the gradient matches the forward pass as implemented
  !! in `duvenaud_update`.
  !!
  !! Arguments:
  !!   this          - result node; `indices` gives the per-vertex entry
  !!                   ranges, `right_operand` holds the forward input,
  !!                   `left_operand%indices(1:2)` the min/max degree
  !!   upstream_grad - gradient with respect to the update output
  !!   output        - gradient with respect to the flat weight vector;
  !!                   zeroed on entry and then accumulated into
  implicit none
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  integer :: v, i, j, bucket, offset
  integer :: interval, num_output_features, num_input_features
  integer :: min_degree, max_degree
  real(real32) :: scale

  output = 0._real32
  num_output_features = size(upstream_grad,1)
  num_input_features = this%right_operand%shape(1)
  interval = num_output_features * num_input_features
  min_degree = this%left_operand%indices(1)
  max_degree = this%left_operand%indices(2)

  ! All vertices that fall in the same degree bucket accumulate into the
  ! same weight block, so iterations over v are NOT independent. An
  ! ordinary DO is required; DO CONCURRENT would violate its independence
  ! rule and race under real parallel execution.
  do v = 1, size(upstream_grad,2)
    ! 1-based degree bucket, same formula as the forward pass
    bucket = max( &
         min_degree, &
         min(this%indices(v+1) - this%indices(v), max_degree ) &
    ) - min_degree + 1
    offset = (bucket - 1) * interval
    scale = real(bucket, real32)
    ! Within one vertex every (i,j) pair targets a distinct element, so the
    ! inner loop is safely concurrent.
    do concurrent(i = 1:num_output_features, j = 1:num_input_features)
      output(offset+i+num_output_features*(j-1),1) = &
           output(offset+i+num_output_features*(j-1),1) + &
           upstream_grad(i,v) * (this%right_operand%val(j,v) / scale)
    end do
  end do

end subroutine get_partial_duvenaud_update_weight_val
270 !###############################################################################
271
272 end submodule athena__diffstruc_extd_submodule_msgpass_duvenaud
273