| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_msgpass_duvenaud | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | |||
| 4 | contains | ||
| 5 | |||
| 6 | !############################################################################### | ||
| 7 | 12 | module function duvenaud_propagate( & | |
| 8 |
2/4✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
12 | vertex_features, edge_features, adj_ia, adj_ja & |
| 9 | ) result(c) | ||
| 10 | !! Propagate values from one autodiff array to another | ||
| 11 | implicit none | ||
| 12 | |||
| 13 | ! Arguments | ||
| 14 | class(array_type), intent(in), target :: vertex_features, edge_features | ||
| 15 | !! Vertex and edge feature tensors | ||
| 16 | integer, dimension(:), intent(in) :: adj_ia | ||
| 17 | !! CSR row pointers | ||
| 18 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 19 | !! CSR neighbour and edge lookup indices | ||
| 20 | type(array_type), pointer :: c | ||
| 21 | !! Propagated concatenated feature tensor | ||
| 22 | |||
| 23 | ! Local variables | ||
| 24 | integer :: v, w | ||
| 25 | !! Vertex and adjacency traversal indices | ||
| 26 | |||
| 27 | c => vertex_features%create_result( & | ||
| 28 | array_shape = [ & | ||
| 29 | size(vertex_features%val,1) + size(edge_features%val,1), & | ||
| 30 | size(vertex_features%val,2) & | ||
| 31 | ] & | ||
| 32 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
|
36 | ) |
| 33 | ! propagate 1D array by using shape to swap dimensions | ||
| 34 | 12 | do concurrent(v=1:size(vertex_features%val,2)) | |
| 35 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 54 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 54 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 54 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 54 times.
✓ Branch 18 taken 372 times.
✓ Branch 19 taken 54 times.
|
426 | c%val(:,v) = 0.0_real32 |
| 36 |
8/12✓ Branch 0 taken 54 times.
✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 54 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 54 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 54 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 54 times.
✓ Branch 14 taken 132 times.
✓ Branch 15 taken 54 times.
|
252 | do w = adj_ia(v), adj_ia(v+1)-1 |
| 37 | ✗ | c%val(:,v) = c%val(:,v) + [ & | |
| 38 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 132 times.
|
132 | vertex_features%val(:, adj_ja(1, w)), & |
| 39 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 132 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 132 times.
|
132 | edge_features%val(:, adj_ja(2, w)) & |
| 40 |
34/62✗ Branch 0 not taken.
✓ Branch 1 taken 132 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 132 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 132 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 132 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 132 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 132 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 132 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 132 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 132 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 132 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 132 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 132 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 132 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 132 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 132 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 132 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 132 times.
✓ Branch 41 taken 696 times.
✓ Branch 42 taken 132 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 132 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 132 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 132 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 132 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 132 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 132 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 132 times.
✓ Branch 63 taken 204 times.
✓ Branch 64 taken 132 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 132 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 132 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 132 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 132 times.
✓ Branch 73 taken 900 times.
✓ Branch 74 taken 132 times.
|
2250 | ] |
| 41 | end do | ||
| 42 | end do | ||
| 43 | |||
| 44 |
7/14✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 66 times.
✓ Branch 16 taken 12 times.
|
78 | c%indices = adj_ia |
| 45 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 12 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 12 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 12 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 132 times.
✓ Branch 27 taken 12 times.
✓ Branch 28 taken 264 times.
✓ Branch 29 taken 132 times.
|
408 | c%adj_ja = adj_ja |
| 46 | 12 | c%get_partial_left => get_partial_duvenaud_propagate_left | |
| 47 | 12 | c%get_partial_right => get_partial_duvenaud_propagate_right | |
| 48 | 12 | c%get_partial_left_val => get_partial_duvenaud_propagate_left_val | |
| 49 | 12 | c%get_partial_right_val => get_partial_duvenaud_propagate_right_val | |
| 50 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 9 times.
|
12 | if(vertex_features%requires_grad .or. edge_features%requires_grad)then |
| 51 | 3 | c%requires_grad = .true. | |
| 52 | 3 | c%is_forward = vertex_features%is_forward .or. edge_features%is_forward | |
| 53 | 3 | c%operation = 'duvenaud_propagate' | |
| 54 | 3 | c%left_operand => vertex_features | |
| 55 | 3 | c%right_operand => edge_features | |
| 56 | 3 | c%owns_left_operand = vertex_features%is_temporary | |
| 57 | 3 | c%owns_right_operand = edge_features%is_temporary | |
| 58 | end if | ||
| 59 | 12 | end function duvenaud_propagate | |
| 60 | !------------------------------------------------------------------------------- | ||
| 61 | ✗ | function get_partial_duvenaud_propagate_left(this, upstream_grad) result(output) | |
| 62 | !! Gradient of duvenaud_propagate with respect to vertex_features. | ||
| 63 | implicit none | ||
| 64 | |||
| 65 | ! Arguments | ||
| 66 | class(array_type), intent(inout) :: this | ||
| 67 | !! Forward result node containing saved operands | ||
| 68 | type(array_type), intent(in) :: upstream_grad | ||
| 69 | !! Upstream gradient tensor | ||
| 70 | type(array_type) :: output | ||
| 71 | !! Gradient tensor for left operand | ||
| 72 | |||
| 73 | ! Local variables | ||
| 74 | logical :: right_is_temporary_local | ||
| 75 | !! Saved temporary-ownership flag for right operand | ||
| 76 | type(array_type), pointer :: ptr | ||
| 77 | !! Intermediate gradient tensor pointer | ||
| 78 | |||
| 79 | ✗ | right_is_temporary_local = this%right_operand%is_temporary | |
| 80 | ✗ | this%right_operand%is_temporary = .false. | |
| 81 | ptr => duvenaud_propagate( upstream_grad, this%right_operand, & | ||
| 82 | ✗ | this%indices, this%adj_ja ) | |
| 83 | ✗ | this%right_operand%is_temporary = right_is_temporary_local | |
| 84 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 85 | |||
| 86 | ✗ | end function get_partial_duvenaud_propagate_left | |
| 87 | !------------------------------------------------------------------------------- | ||
| 88 | ✗ | function get_partial_duvenaud_propagate_right(this, upstream_grad) result(output) | |
| 89 | !! Gradient of duvenaud_propagate with respect to edge_features. | ||
| 90 | implicit none | ||
| 91 | |||
| 92 | ! Arguments | ||
| 93 | class(array_type), intent(inout) :: this | ||
| 94 | !! Forward result node containing saved operands | ||
| 95 | type(array_type), intent(in) :: upstream_grad | ||
| 96 | !! Upstream gradient tensor | ||
| 97 | type(array_type) :: output | ||
| 98 | !! Gradient tensor for right operand | ||
| 99 | |||
| 100 | ! Local variables | ||
| 101 | logical :: left_is_temporary_local | ||
| 102 | !! Saved temporary-ownership flag for left operand | ||
| 103 | type(array_type), pointer :: ptr | ||
| 104 | !! Intermediate gradient tensor pointer | ||
| 105 | |||
| 106 | ✗ | left_is_temporary_local = this%left_operand%is_temporary | |
| 107 | ✗ | this%left_operand%is_temporary = .false. | |
| 108 | ptr => duvenaud_propagate( this%left_operand, upstream_grad, & | ||
| 109 | ✗ | this%indices, this%adj_ja ) | |
| 110 | ✗ | this%left_operand%is_temporary = left_is_temporary_local | |
| 111 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 112 | |||
| 113 | ✗ | end function get_partial_duvenaud_propagate_right | |
| 114 | !------------------------------------------------------------------------------- | ||
| 115 | ✗ | pure subroutine get_partial_duvenaud_propagate_left_val( & | |
| 116 | ✗ | this, upstream_grad, output & | |
| 117 | ) | ||
| 118 | !! In-place value gradient for duvenaud_propagate left operand. | ||
| 119 | implicit none | ||
| 120 | |||
| 121 | ! Arguments | ||
| 122 | class(array_type), intent(in) :: this | ||
| 123 | !! Forward result node containing saved operands | ||
| 124 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 125 | !! Upstream gradient values | ||
| 126 | real(real32), dimension(:,:), intent(out) :: output | ||
| 127 | !! Output gradient values for left operand | ||
| 128 | |||
| 129 | ! Local variables | ||
| 130 | integer :: v, w, num_features, num_elements | ||
| 131 | !! Loop indices and operand shape values | ||
| 132 | |||
| 133 | ✗ | num_features = size(this%left_operand%val,1) | |
| 134 | ✗ | num_elements = size(this%left_operand%val,2) | |
| 135 | ✗ | output = 0._real32 | |
| 136 | ✗ | do concurrent(v=1:num_elements) | |
| 137 | ✗ | do w = this%indices(v), this%indices(v+1)-1 | |
| 138 | ✗ | output(:,this%adj_ja(1,w)) = output(:,this%adj_ja(1,w)) + & | |
| 139 | ✗ | [ upstream_grad(1:num_features, v) ] | |
| 140 | end do | ||
| 141 | end do | ||
| 142 | ✗ | end subroutine get_partial_duvenaud_propagate_left_val | |
| 143 | !------------------------------------------------------------------------------- | ||
| 144 | ✗ | pure subroutine get_partial_duvenaud_propagate_right_val( & | |
| 145 | ✗ | this, upstream_grad, output & | |
| 146 | ) | ||
| 147 | !! In-place value gradient for duvenaud_propagate right operand. | ||
| 148 | implicit none | ||
| 149 | |||
| 150 | ! Arguments | ||
| 151 | class(array_type), intent(in) :: this | ||
| 152 | !! Forward result node containing saved operands | ||
| 153 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 154 | !! Upstream gradient values | ||
| 155 | real(real32), dimension(:,:), intent(out) :: output | ||
| 156 | !! Output gradient values for right operand | ||
| 157 | |||
| 158 | ! Local variables | ||
| 159 | integer :: v, w, num_features, num_elements | ||
| 160 | !! Loop indices and operand shape values | ||
| 161 | |||
| 162 | ✗ | num_features = size(this%left_operand%val,1) | |
| 163 | ✗ | num_elements = size(this%left_operand%val,2) | |
| 164 | ✗ | output = 0._real32 | |
| 165 | ✗ | do concurrent(v=1:num_elements) | |
| 166 | ✗ | do w = this%indices(v), this%indices(v+1)-1 | |
| 167 | ✗ | output(:,this%adj_ja(2,w)) = output(:,this%adj_ja(2,w)) + & | |
| 168 | ✗ | [ upstream_grad(num_features+1:, v) ] | |
| 169 | end do | ||
| 170 | end do | ||
| 171 | ✗ | end subroutine get_partial_duvenaud_propagate_right_val | |
| 172 | !############################################################################### | ||
| 173 | |||
| 174 | |||
| 175 | !############################################################################### | ||
| 176 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c) |
| 177 | !! Update the message passing layer | ||
| 178 | implicit none | ||
| 179 | |||
| 180 | ! Arguments | ||
| 181 | class(array_type), intent(in), target :: a | ||
| 182 | !! Aggregated neighbour features | ||
| 183 | class(array_type), intent(in), target :: weight | ||
| 184 | !! Packed degree-conditioned weight tensor | ||
| 185 | ! real(real32), dimension(:,:,:), intent(in) :: weight | ||
| 186 | integer, dimension(:), intent(in) :: adj_ia | ||
| 187 | !! CSR row pointers | ||
| 188 | integer, intent(in) :: min_degree, max_degree | ||
| 189 | !! Minimum and maximum degree buckets | ||
| 190 | type(array_type), pointer :: c | ||
| 191 | !! Degree-conditioned updated feature tensor | ||
| 192 | type(array_type), pointer :: weight_array | ||
| 193 | !! Reserved pointer for weight reshaping operations | ||
| 194 | |||
| 195 | ! Local variables | ||
| 196 | integer :: v, i, d | ||
| 197 | !! Loop indices and degree bucket index | ||
| 198 | integer :: interval | ||
| 199 | !! Flat parameter interval for one degree bucket | ||
| 200 | 12 | real(real32), pointer :: w_ptr(:,:) | |
| 201 | !! 2D view over selected degree-specific weight matrix | ||
| 202 | |||
| 203 |
3/4✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✓ Branch 4 taken 24 times.
✓ Branch 5 taken 12 times.
|
36 | c => a%create_result(array_shape=[weight%shape(1), size(a%val,2)]) |
| 204 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
|
12 | interval = weight%shape(1) * weight%shape(2) |
| 205 |
2/2✓ Branch 0 taken 54 times.
✓ Branch 1 taken 12 times.
|
66 | do v = 1, size(a%val,2) |
| 206 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 54 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 54 times.
|
108 | d = max( min_degree, min( adj_ia(v+1) - adj_ia(v), max_degree ) ) - & |
| 207 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 54 times.
|
108 | min_degree + 1 |
| 208 | 216 | w_ptr(1:weight%shape(1), 1:weight%shape(2)) => & | |
| 209 |
11/22✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 54 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 54 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 54 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 54 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 54 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 54 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 54 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 54 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 54 times.
|
54 | weight%val(interval*(d-1)+1:interval*d,1) |
| 210 |
17/30✗ Branch 0 not taken.
✓ Branch 1 taken 54 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 54 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 54 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 54 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 54 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 54 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 54 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 54 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 54 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 54 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 54 times.
✓ Branch 32 taken 372 times.
✓ Branch 33 taken 54 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 54 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 54 times.
✓ Branch 41 taken 288 times.
✓ Branch 42 taken 54 times.
|
726 | c%val(:,v) = matmul(w_ptr, a%val(:,v) / real(d, real32)) |
| 211 | end do | ||
| 212 |
7/14✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 12 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 12 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 12 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 12 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 66 times.
✓ Branch 16 taken 12 times.
|
78 | c%indices = adj_ia |
| 213 | |||
| 214 | 12 | c%get_partial_left => get_partial_duvenaud_update_weight | |
| 215 | 12 | c%get_partial_right => get_partial_duvenaud_update | |
| 216 | 12 | c%get_partial_left_val => get_partial_duvenaud_update_weight_val | |
| 217 | 12 | c%get_partial_right_val => get_partial_duvenaud_update_val | |
| 218 |
1/2✓ Branch 0 taken 12 times.
✗ Branch 1 not taken.
|
12 | if(a%requires_grad .or. weight%requires_grad)then |
| 219 | 12 | c%requires_grad = .true. | |
| 220 | 12 | c%is_forward = a%is_forward .or. weight%is_forward | |
| 221 | 12 | c%operation = 'duvenaud_update' | |
| 222 | 12 | c%right_operand => a | |
| 223 | 12 | c%left_operand => weight | |
| 224 | 12 | c%owns_right_operand = a%is_temporary | |
| 225 | 12 | c%owns_left_operand = weight%is_temporary | |
| 226 | end if | ||
| 227 | |||
| 228 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 12 times.
|
24 | end function duvenaud_update |
| 229 | !------------------------------------------------------------------------------- | ||
| 230 | ✗ | function get_partial_duvenaud_update(this, upstream_grad) result(output) | |
| 231 | !! Gradient of duvenaud_update with respect to input features. | ||
| 232 | implicit none | ||
| 233 | |||
| 234 | ! Arguments | ||
| 235 | class(array_type), intent(inout) :: this | ||
| 236 | !! Forward result node containing saved operands | ||
| 237 | type(array_type), intent(in) :: upstream_grad | ||
| 238 | !! Upstream gradient tensor | ||
| 239 | type(array_type) :: output | ||
| 240 | !! Gradient tensor for right operand (input features) | ||
| 241 | |||
| 242 | ! Local variables | ||
| 243 | logical :: left_is_temporary_local | ||
| 244 | !! Saved temporary-ownership flag for left operand | ||
| 245 | type(array_type), pointer :: ptr | ||
| 246 | !! Intermediate gradient tensor pointer | ||
| 247 | |||
| 248 | ✗ | left_is_temporary_local = this%left_operand%is_temporary | |
| 249 | ✗ | this%left_operand%is_temporary = .false. | |
| 250 | ptr => duvenaud_update( upstream_grad, this%left_operand, & | ||
| 251 | ✗ | this%indices, this%left_operand%indices(1), this%left_operand%indices(2) ) | |
| 252 | ✗ | this%left_operand%is_temporary = left_is_temporary_local | |
| 253 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 254 | |||
| 255 | ✗ | end function get_partial_duvenaud_update | |
| 256 | !------------------------------------------------------------------------------- | ||
| 257 | ✗ | function get_partial_duvenaud_update_weight(this, upstream_grad) result(output) | |
| 258 | !! Gradient of duvenaud_update with respect to packed weights. | ||
| 259 | implicit none | ||
| 260 | |||
| 261 | ! Arguments | ||
| 262 | class(array_type), intent(inout) :: this | ||
| 263 | !! Forward result node containing saved operands | ||
| 264 | type(array_type), intent(in) :: upstream_grad | ||
| 265 | !! Upstream gradient tensor | ||
| 266 | type(array_type) :: output | ||
| 267 | !! Gradient tensor for left operand (weights) | ||
| 268 | |||
| 269 | ! Local variables | ||
| 270 | logical :: right_is_temporary_local | ||
| 271 | !! Saved temporary-ownership flag for right operand | ||
| 272 | type(array_type), pointer :: ptr | ||
| 273 | !! Intermediate gradient tensor pointer | ||
| 274 | |||
| 275 | ✗ | right_is_temporary_local = this%right_operand%is_temporary | |
| 276 | ✗ | this%right_operand%is_temporary = .false. | |
| 277 | ptr => duvenaud_update( this%right_operand, upstream_grad, & | ||
| 278 | ✗ | this%indices, this%left_operand%indices(1), this%left_operand%indices(2) ) | |
| 279 | ✗ | this%right_operand%is_temporary = right_is_temporary_local | |
| 280 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 281 | |||
| 282 | ✗ | end function get_partial_duvenaud_update_weight | |
| 283 | !------------------------------------------------------------------------------- | ||
| 284 | ✗ | pure subroutine get_partial_duvenaud_update_val( & | |
| 285 | ✗ | this, upstream_grad, output & | |
| 286 | ) | ||
| 287 | !! In-place value gradient for duvenaud_update input features. | ||
| 288 | implicit none | ||
| 289 | |||
| 290 | ! Arguments | ||
| 291 | class(array_type), intent(in) :: this | ||
| 292 | !! Forward result node containing saved operands | ||
| 293 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 294 | !! Upstream gradient values | ||
| 295 | real(real32), dimension(:,:), intent(out) :: output | ||
| 296 | !! Output gradient values for input features | ||
| 297 | |||
| 298 | ! Local variables | ||
| 299 | integer :: v, d | ||
| 300 | !! Loop index and degree bucket index | ||
| 301 | integer :: interval, num_output_features, num_input_features | ||
| 302 | !! Flattening interval and matrix dimensions | ||
| 303 | integer :: min_degree, max_degree | ||
| 304 | !! Degree bucket limits | ||
| 305 | ✗ | real(real32), dimension(size(upstream_grad,1), this%right_operand%shape(1)) :: tmp | |
| 306 | !! Temporary reshaped weight matrix for one degree bucket | ||
| 307 | |||
| 308 | ✗ | output = 0._real32 | |
| 309 | ✗ | num_output_features = size(upstream_grad,1) | |
| 310 | ✗ | num_input_features = this%right_operand%shape(1) | |
| 311 | ✗ | interval = num_output_features * num_input_features | |
| 312 | ✗ | min_degree = this%left_operand%indices(1) | |
| 313 | ✗ | max_degree = this%left_operand%indices(2) | |
| 314 | ✗ | do concurrent(v=1:size(upstream_grad,2)) | |
| 315 | d = max( & | ||
| 316 | min_degree, & | ||
| 317 | ✗ | min(this%indices(v+1) - this%indices(v), max_degree ) & | |
| 318 | ✗ | ) - min_degree + 1 | |
| 319 | ✗ | tmp = reshape(this%left_operand%val((d-1)*interval+1:d*interval,1), & | |
| 320 | ✗ | [num_output_features, num_input_features] ) | |
| 321 | ✗ | output(:,v) = matmul(upstream_grad(:,v), tmp) / real(d, real32) | |
| 322 | end do | ||
| 323 | |||
| 324 | ✗ | end subroutine get_partial_duvenaud_update_val | |
| 325 | !------------------------------------------------------------------------------- | ||
| 326 | 5 | pure subroutine get_partial_duvenaud_update_weight_val( & | |
| 327 |
2/4✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
|
5 | this, upstream_grad, output & |
| 328 | ) | ||
| 329 | !! In-place value gradient for duvenaud_update packed weights. | ||
| 330 | implicit none | ||
| 331 | |||
| 332 | ! Arguments | ||
| 333 | class(array_type), intent(in) :: this | ||
| 334 | !! Forward result node containing saved operands | ||
| 335 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 336 | !! Upstream gradient values | ||
| 337 | real(real32), dimension(:,:), intent(out) :: output | ||
| 338 | !! Output gradient values for packed weights | ||
| 339 | |||
| 340 | ! Local variables | ||
| 341 | integer :: v, i, j, d_offset, d_val | ||
| 342 | !! Loop indices, degree offset and degree bucket index | ||
| 343 | integer :: interval, num_output_features, num_input_features | ||
| 344 | !! Flattening interval and matrix dimensions | ||
| 345 | integer :: min_degree, max_degree | ||
| 346 | !! Degree bucket limits | ||
| 347 | |||
| 348 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✓ Branch 18 taken 25 times.
✓ Branch 19 taken 5 times.
✓ Branch 20 taken 8000 times.
✓ Branch 21 taken 25 times.
|
8030 | output = 0._real32 |
| 349 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
|
5 | num_output_features = size(upstream_grad,1) |
| 350 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
|
5 | num_input_features = this%right_operand%shape(1) |
| 351 | 5 | interval = num_output_features * num_input_features | |
| 352 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
|
5 | min_degree = this%left_operand%indices(1) |
| 353 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
|
5 | max_degree = this%left_operand%indices(2) |
| 354 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
|
5 | do concurrent(v=1:size(upstream_grad,2)) |
| 355 | d_val = max( & | ||
| 356 | min_degree, & | ||
| 357 | 100 | min(this%indices(v+1) - this%indices(v), max_degree ) & | |
| 358 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 25 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 25 times.
|
25 | ) - min_degree + 1 |
| 359 | 25 | d_offset = (d_val - 1) * interval | |
| 360 |
2/2✓ Branch 0 taken 25 times.
✓ Branch 1 taken 5 times.
|
30 | do concurrent(i = 1:num_output_features, j=1:num_input_features) |
| 361 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2000 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2000 times.
|
4000 | output(d_offset+i+num_output_features*(j-1),1) = & |
| 362 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2000 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2000 times.
|
4000 | output(d_offset+i+num_output_features*(j-1),1) + & |
| 363 |
2/4✗ Branch 1 not taken.
✓ Branch 2 taken 2000 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2000 times.
|
4000 | upstream_grad(i,v) * this%right_operand%val(j,v) / & |
| 364 |
12/20✓ Branch 0 taken 250 times.
✓ Branch 1 taken 25 times.
✓ Branch 2 taken 2000 times.
✓ Branch 3 taken 250 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2000 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2000 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2000 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2000 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2000 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2000 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2000 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 2000 times.
|
12275 | real(d_val, real32) |
| 365 | end do | ||
| 366 | end do | ||
| 367 | |||
| 368 | 5 | end subroutine get_partial_duvenaud_update_weight_val | |
| 369 | !############################################################################### | ||
| 370 | |||
| 371 | end submodule athena__diffstruc_extd_submodule_msgpass_duvenaud | ||
| 372 |