submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_msgpass_duvenaud
  !! Submodule containing implementations for extended diffstruc array operations
  !!
  !! Provides the `duvenaud_propagate` and `duvenaud_update` operations
  !! together with the partial-derivative procedures that each result
  !! array is wired to (`get_partial_left`, `get_partial_right` and their
  !! `_val` counterparts operating on raw value arrays).

contains

!###############################################################################
  module function duvenaud_propagate( &
       vertex_features, edge_features, adj_ia, adj_ja &
  ) result(c)
    !! Aggregate vertex and edge features over each vertex's adjacency list
    !!
    !! For every vertex `v`, column `v` of the result is the sum over the
    !! adjacency entries `w = adj_ia(v), ..., adj_ia(v+1)-1` of the
    !! concatenation of vertex column `adj_ja(1,w)` and edge column
    !! `adj_ja(2,w)`. The result therefore has
    !! `size(vertex_features%val,1) + size(edge_features%val,1)` rows and
    !! `size(vertex_features%val,2)` columns.
    implicit none

    ! Arguments
    class(array_type), intent(in), target :: vertex_features, edge_features
    !! Vertex and edge feature arrays (features along dimension 1)
    integer, dimension(:), intent(in) :: adj_ia
    !! Start offsets into `adj_ja`; entry `v+1` is read for vertex `v`
    integer, dimension(:,:), intent(in) :: adj_ja
    !! Row 1: neighbouring vertex column; row 2: connecting edge column
    type(array_type), pointer :: c
    !! Newly created result array

    ! Local variables
    integer :: v, w
    integer :: num_vertex_features

    num_vertex_features = size(vertex_features%val,1)
    c => vertex_features%create_result( &
         array_shape = [ &
              num_vertex_features + size(edge_features%val,1), &
              size(vertex_features%val,2) &
         ] &
    )

    ! Each iteration only writes its own column of c%val, so the iterations
    ! are independent. The two slice updates are equivalent to adding the
    ! concatenated [vertex, edge] vector, without building a temporary.
    do concurrent(v=1:size(vertex_features%val,2))
       c%val(:,v) = 0.0_real32
       do w = adj_ia(v), adj_ia(v+1)-1
          c%val(:num_vertex_features,v) = c%val(:num_vertex_features,v) + &
               vertex_features%val(:, adj_ja(1, w))
          c%val(num_vertex_features+1:,v) = &
               c%val(num_vertex_features+1:,v) + &
               edge_features%val(:, adj_ja(2, w))
       end do
    end do

    ! Keep the graph structure on the result; the partial-derivative
    ! procedures below read it back through this%indices and this%adj_ja.
    c%indices = adj_ia
    c%adj_ja = adj_ja
    c%get_partial_left => get_partial_duvenaud_propagate_left
    c%get_partial_right => get_partial_duvenaud_propagate_right
    c%get_partial_left_val => get_partial_duvenaud_propagate_left_val
    c%get_partial_right_val => get_partial_duvenaud_propagate_right_val
    if(vertex_features%requires_grad .or. edge_features%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = vertex_features%is_forward .or. edge_features%is_forward
       c%operation = 'duvenaud_propagate'
       c%left_operand => vertex_features
       c%right_operand => edge_features
       c%owns_left_operand = vertex_features%is_temporary
       c%owns_right_operand = edge_features%is_temporary
    end if
  end function duvenaud_propagate
!-------------------------------------------------------------------------------
  function get_partial_duvenaud_propagate_left(this, upstream_grad) result(output)
    !! Partial of `duvenaud_propagate` for its left (vertex) operand,
    !! returned as an `array_type`
    !!
    !! Re-applies `duvenaud_propagate` with `upstream_grad` in place of the
    !! vertex features and the stored right operand as the edge features.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    logical :: right_is_temporary_local
    type(array_type), pointer :: ptr

    ! duvenaud_propagate copies operand%is_temporary into owns_*_operand;
    ! clear the flag for the duration of the call so the intermediate
    ! result does not record ownership of this node's operand, then
    ! restore it.
    right_is_temporary_local = this%right_operand%is_temporary
    this%right_operand%is_temporary = .false.
    ptr => duvenaud_propagate( upstream_grad, this%right_operand, &
         this%indices, this%adj_ja )
    this%right_operand%is_temporary = right_is_temporary_local
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_duvenaud_propagate_left
!-------------------------------------------------------------------------------
  function get_partial_duvenaud_propagate_right(this, upstream_grad) result(output)
    !! Partial of `duvenaud_propagate` for its right (edge) operand,
    !! returned as an `array_type`
    !!
    !! Re-applies `duvenaud_propagate` with the stored left operand as the
    !! vertex features and `upstream_grad` in place of the edge features.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    ! See get_partial_duvenaud_propagate_left: the flag is cleared only
    ! while the intermediate result is being built.
    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.
    ptr => duvenaud_propagate( this%left_operand, upstream_grad, &
         this%indices, this%adj_ja )
    this%left_operand%is_temporary = left_is_temporary_local
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_duvenaud_propagate_right
!-------------------------------------------------------------------------------
  pure subroutine get_partial_duvenaud_propagate_left_val( &
       this, upstream_grad, output &
  )
    !! Gradient of `duvenaud_propagate` with respect to the vertex features
    !!
    !! Scatters the leading `size(this%left_operand%val,1)` rows of each
    !! upstream column `v` onto every vertex column `adj_ja(1,w)` that
    !! contributed to column `v` in the forward pass.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: v, w, num_features, num_elements

    num_features = size(this%left_operand%val,1)
    num_elements = size(this%left_operand%val,2)
    output = 0._real32
    ! Ordered loop on purpose: several vertices v can share the same
    ! neighbour column adj_ja(1,w), so the iterations are not independent
    ! and must not be expressed as DO CONCURRENT.
    do v = 1, num_elements
       do w = this%indices(v), this%indices(v+1)-1
          output(:,this%adj_ja(1,w)) = output(:,this%adj_ja(1,w)) + &
               upstream_grad(1:num_features, v)
       end do
    end do
  end subroutine get_partial_duvenaud_propagate_left_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_duvenaud_propagate_right_val( &
       this, upstream_grad, output &
  )
    !! Gradient of `duvenaud_propagate` with respect to the edge features
    !!
    !! Scatters the rows of each upstream column `v` that follow the first
    !! `size(this%left_operand%val,1)` rows onto every edge column
    !! `adj_ja(2,w)` that contributed to column `v` in the forward pass.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: v, w, num_features, num_elements

    num_features = size(this%left_operand%val,1)
    num_elements = size(this%left_operand%val,2)
    output = 0._real32
    ! Ordered loop on purpose: the same edge column adj_ja(2,w) can be
    ! reached from more than one vertex v, so the accumulation cannot be
    ! expressed as DO CONCURRENT.
    do v = 1, num_elements
       do w = this%indices(v), this%indices(v+1)-1
          output(:,this%adj_ja(2,w)) = output(:,this%adj_ja(2,w)) + &
               upstream_grad(num_features+1:, v)
       end do
    end do
  end subroutine get_partial_duvenaud_propagate_right_val
!###############################################################################


!###############################################################################
  module function duvenaud_update(a, weight, adj_ia, min_degree, max_degree) result(c)
    !! Update the message passing layer
    !!
    !! `weight%val(:,1)` is read as a sequence of matrices of shape
    !! `weight%shape(1) x weight%shape(2)`, one per degree bin, each stored
    !! contiguously in column-major order. For vertex `v` the bin is
    !! `d = max(min_degree, min(adj_ia(v+1)-adj_ia(v), max_degree))
    !!      - min_degree + 1`
    !! and column `v` of the result is `W_d * (a(:,v) / d)`.
    !! Note that the divisor is the bin index `d`, not the raw degree.
    implicit none
    class(array_type), intent(in), target :: a
    !! Input features, one column per vertex
    class(array_type), intent(in), target :: weight
    !! Flattened per-degree weight matrices
    integer, dimension(:), intent(in) :: adj_ia
    !! Adjacency offsets; `adj_ia(v+1) - adj_ia(v)` is the degree of `v`
    integer, intent(in) :: min_degree, max_degree
    !! Bounds the degree is clamped to before selecting a weight matrix
    type(array_type), pointer :: c
    !! Newly created result array

    integer :: v, d
    integer :: interval
    real(real32), pointer :: w_ptr(:,:)

    c => a%create_result(array_shape=[weight%shape(1), size(a%val,2)])
    ! Number of elements in one per-degree weight matrix
    interval = weight%shape(1) * weight%shape(2)
    do v = 1, size(a%val,2)
       d = max( min_degree, min( adj_ia(v+1) - adj_ia(v), max_degree ) ) - &
            min_degree + 1
       ! View the d-th flat segment as a 2D matrix without copying
       w_ptr(1:weight%shape(1), 1:weight%shape(2)) => &
            weight%val(interval*(d-1)+1:interval*d,1)
       c%val(:,v) = matmul(w_ptr, a%val(:,v) / real(d, real32))
    end do
    c%indices = adj_ia

    ! The weight is the LEFT operand and the features are the RIGHT operand.
    c%get_partial_left => get_partial_duvenaud_update_weight
    c%get_partial_right => get_partial_duvenaud_update
    c%get_partial_left_val => get_partial_duvenaud_update_weight_val
    c%get_partial_right_val => get_partial_duvenaud_update_val
    if(a%requires_grad .or. weight%requires_grad) then
       c%requires_grad = .true.
       c%is_forward = a%is_forward .or. weight%is_forward
       c%operation = 'duvenaud_update'
       c%right_operand => a
       c%left_operand => weight
       c%owns_right_operand = a%is_temporary
       c%owns_left_operand = weight%is_temporary
    end if

  end function duvenaud_update
!-------------------------------------------------------------------------------
  function get_partial_duvenaud_update(this, upstream_grad) result(output)
    !! Partial of `duvenaud_update` for its feature (right) operand,
    !! returned as an `array_type`
    !!
    !! Re-applies `duvenaud_update` to `upstream_grad` with the stored
    !! weight operand. The degree bounds are read from
    !! `this%left_operand%indices(1:2)`.
    !! NOTE(review): nothing in this file sets `indices` on the weight
    !! array; presumably the caller stores [min_degree, max_degree]
    !! there - confirm.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output
    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    ! Clear the flag while the intermediate result is built so that it
    ! does not record ownership of this node's weight operand.
    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.
    ptr => duvenaud_update( upstream_grad, this%left_operand, &
         this%indices, this%left_operand%indices(1), this%left_operand%indices(2) )
    this%left_operand%is_temporary = left_is_temporary_local
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_duvenaud_update
!-------------------------------------------------------------------------------
  function get_partial_duvenaud_update_weight(this, upstream_grad) result(output)
    !! Partial of `duvenaud_update` for its weight (left) operand,
    !! returned as an `array_type`
    !!
    !! Re-applies `duvenaud_update` with the stored feature operand and
    !! `upstream_grad` passed in the weight position. The degree bounds
    !! are read from `this%left_operand%indices(1:2)`.
    !! NOTE(review): presumably the caller stores [min_degree, max_degree]
    !! in the weight array's `indices` - confirm.
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output
    logical :: right_is_temporary_local
    type(array_type), pointer :: ptr

    ! Clear the flag while the intermediate result is built so that it
    ! does not record ownership of this node's feature operand.
    right_is_temporary_local = this%right_operand%is_temporary
    this%right_operand%is_temporary = .false.
    ptr => duvenaud_update( this%right_operand, upstream_grad, &
         this%indices, this%left_operand%indices(1), this%left_operand%indices(2) )
    this%right_operand%is_temporary = right_is_temporary_local
    call output%assign_and_deallocate_source(ptr)

  end function get_partial_duvenaud_update_weight
!-------------------------------------------------------------------------------
  pure subroutine get_partial_duvenaud_update_val( &
       this, upstream_grad, output &
  )
    !! Gradient of `duvenaud_update` with respect to the input features
    !!
    !! The forward pass computes `c(:,v) = W_d * (a(:,v) / d)`, so the
    !! gradient for column `v` is `transpose(W_d) * upstream_grad(:,v) / d`
    !! with the same degree bin `d` as in the forward pass.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: v, d
    integer :: interval, num_output_features, num_input_features
    integer :: min_degree, max_degree

    output = 0._real32
    num_output_features = size(upstream_grad,1)
    num_input_features = this%right_operand%shape(1)
    interval = num_output_features * num_input_features
    min_degree = this%left_operand%indices(1)
    max_degree = this%left_operand%indices(2)
    ! Each iteration writes only output(:,v); no shared work array is used,
    ! so the iterations are independent.
    do concurrent(v=1:size(upstream_grad,2))
       d = max( &
            min_degree, &
            min(this%indices(v+1) - this%indices(v), max_degree ) &
       ) - min_degree + 1
       ! vector * matrix gives transpose(W_d) * g; the 1/d factor mirrors
       ! the scaling applied to the input in the forward pass.
       output(:,v) = matmul( &
            upstream_grad(:,v), &
            reshape( &
                 this%left_operand%val((d-1)*interval+1:d*interval,1), &
                 [num_output_features, num_input_features] &
            ) &
       ) / real(d, real32)
    end do

  end subroutine get_partial_duvenaud_update_val
!-------------------------------------------------------------------------------
  pure subroutine get_partial_duvenaud_update_weight_val( &
       this, upstream_grad, output &
  )
    !! Gradient of `duvenaud_update` with respect to the flattened weights
    !!
    !! The forward pass computes `c(:,v) = W_d * (a(:,v) / d)`, so vertex
    !! `v` contributes `upstream_grad(i,v) * a(j,v) / d` to element `(i,j)`
    !! of the weight matrix of its degree bin `d`. Element `(i,j)` of bin
    !! `d` lives at `output((d-1)*interval + i + num_output_features*(j-1), 1)`.
    implicit none
    class(array_type), intent(in) :: this
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    real(real32), dimension(:,:), intent(out) :: output

    integer :: v, j, d
    integer :: offset
    integer :: interval, num_output_features, num_input_features
    integer :: min_degree, max_degree

    output = 0._real32
    num_output_features = size(upstream_grad,1)
    num_input_features = this%right_operand%shape(1)
    interval = num_output_features * num_input_features
    min_degree = this%left_operand%indices(1)
    max_degree = this%left_operand%indices(2)
    ! Ordered loop on purpose: all vertices that fall into the same degree
    ! bin accumulate into the same weight-gradient elements, so the
    ! iterations over v are not independent.
    do v = 1, size(upstream_grad,2)
       d = max( &
            min_degree, &
            min(this%indices(v+1) - this%indices(v), max_degree ) &
       ) - min_degree + 1
       do j = 1, num_input_features
          ! Start (exclusive) of column j of the d-th weight matrix
          offset = (d - 1) * interval + num_output_features * (j - 1)
          output(offset+1:offset+num_output_features,1) = &
               output(offset+1:offset+num_output_features,1) + &
               upstream_grad(:,v) * &
               ( this%right_operand%val(j,v) / real(d, real32) )
       end do
    end do

  end subroutine get_partial_duvenaud_update_weight_val
!###############################################################################

end submodule athena__diffstruc_extd_submodule_msgpass_duvenaud