| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__diffstruc_extd | ||
| 2 | !! Module for extended differential structure types for Athena | ||
| 3 | use coreutils, only: real32 | ||
| 4 | use diffstruc, only: array_type | ||
| 5 | use athena__misc_types, only: facets_type | ||
| 6 | implicit none | ||
| 7 | |||
| 8 | |||
| 9 | private | ||
| 10 | |||
| 11 | public :: array_ptr_type | ||
| 12 | public :: add_layers, concat_layers | ||
| 13 | public :: add_bias | ||
| 14 | public :: piecewise, softmax, swish | ||
| 15 | public :: huber | ||
| 16 | public :: avgpool1d, avgpool2d, avgpool3d | ||
| 17 | public :: maxpool1d, maxpool2d, maxpool3d | ||
| 18 | public :: pad1d, pad2d, pad3d | ||
| 19 | public :: merge_over_channels | ||
| 20 | public :: batchnorm_array_type, batchnorm, batchnorm_inference | ||
| 21 | public :: conv1d, conv2d, conv3d | ||
| 22 | public :: kipf_propagate, kipf_update | ||
| 23 | public :: duvenaud_propagate, duvenaud_update | ||
| 24 | public :: gno_kernel_eval, gno_aggregate | ||
| 25 | public :: lno_encode, lno_decode, elem_scale | ||
| 26 | public :: ono_encode, ono_decode | ||
| 27 | |||
| 28 | |||
| 29 | type, extends(array_type) :: batchnorm_array_type | ||
| 30 | real(real32), dimension(:), allocatable :: mean | ||
| 31 | real(real32), dimension(:), allocatable :: variance | ||
| 32 | real(real32) :: epsilon | ||
| 33 | end type batchnorm_array_type | ||
| 34 | |||
| 35 | |||
| 36 | !------------------------------------------------------------------------------- | ||
| 37 | ! Array container types | ||
| 38 | !------------------------------------------------------------------------------- | ||
| 39 | type :: array_ptr_type | ||
| 40 | type(array_type), pointer :: array(:,:) | ||
| 41 | end type array_ptr_type | ||
| 42 | |||
| 43 | ! Operator interfaces | ||
| 44 | !----------------------------------------------------------------------------- | ||
| 45 | interface add_layers | ||
| 46 | module function add_array_ptr(a, idx1, idx2) result(c) | ||
| 47 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 48 | integer, intent(in) :: idx1, idx2 | ||
| 49 | type(array_type), pointer :: c | ||
| 50 | end function add_array_ptr | ||
| 51 | end interface | ||
| 52 | |||
| 53 | interface concat_layers | ||
| 54 | module function concat_array_ptr(a, idx1, idx2, dim) result(c) | ||
| 55 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 56 | integer, intent(in) :: idx1, idx2, dim | ||
| 57 | type(array_type), pointer :: c | ||
| 58 | end function concat_array_ptr | ||
| 59 | end interface | ||
| 60 | !------------------------------------------------------------------------------- | ||
| 61 | |||
| 62 | |||
| 63 | !------------------------------------------------------------------------------- | ||
| 64 | ! Activation functions and other operations | ||
| 65 | !------------------------------------------------------------------------------- | ||
| 66 | interface | ||
| 67 | module function add_bias(input, bias, dim, dim_act_on_shape) result(output) | ||
| 68 | class(array_type), intent(in), target :: input | ||
| 69 | class(array_type), intent(in), target :: bias | ||
| 70 | integer, intent(in) :: dim | ||
| 71 | logical, intent(in), optional :: dim_act_on_shape | ||
| 72 | type(array_type), pointer :: output | ||
| 73 | end function add_bias | ||
| 74 | end interface | ||
| 75 | |||
| 76 | interface piecewise | ||
| 77 | module function piecewise_array(input, gradient, limit) result( output ) | ||
| 78 | class(array_type), intent(in), target :: input | ||
| 79 | real(real32), intent(in) :: gradient | ||
| 80 | real(real32), intent(in) :: limit | ||
| 81 | type(array_type), pointer :: output | ||
| 82 | end function piecewise_array | ||
| 83 | end interface | ||
| 84 | |||
| 85 | interface softmax | ||
| 86 | module function softmax_array(input, dim) result(output) | ||
| 87 | class(array_type), intent(in), target :: input | ||
| 88 | integer, intent(in) :: dim | ||
| 89 | type(array_type), pointer :: output | ||
| 90 | end function softmax_array | ||
| 91 | end interface | ||
| 92 | |||
| 93 | interface swish | ||
| 94 | module function swish_array(input, beta) result(output) | ||
| 95 | class(array_type), intent(in), target :: input | ||
| 96 | real(real32), intent(in) :: beta | ||
| 97 | type(array_type), pointer :: output | ||
| 98 | end function swish_array | ||
| 99 | end interface | ||
| 100 | !------------------------------------------------------------------------------- | ||
| 101 | |||
| 102 | |||
| 103 | !------------------------------------------------------------------------------- | ||
| 104 | ! Loss functions | ||
| 105 | !------------------------------------------------------------------------------- | ||
| 106 | interface huber | ||
| 107 | module function huber_array(delta, gamma) result( output ) | ||
| 108 | class(array_type), intent(in), target :: delta | ||
| 109 | real(real32), intent(in) :: gamma | ||
| 110 | type(array_type), pointer :: output | ||
| 111 | end function huber_array | ||
| 112 | end interface | ||
| 113 | !------------------------------------------------------------------------------- | ||
| 114 | |||
| 115 | |||
| 116 | !------------------------------------------------------------------------------- | ||
| 117 | ! Layer operations | ||
| 118 | !------------------------------------------------------------------------------- | ||
| 119 | interface | ||
| 120 | module function avgpool1d(input, pool_size, stride) result(output) | ||
| 121 | type(array_type), intent(in), target :: input | ||
| 122 | integer, intent(in) :: pool_size | ||
| 123 | integer, intent(in) :: stride | ||
| 124 | type(array_type), pointer :: output | ||
| 125 | end function avgpool1d | ||
| 126 | |||
| 127 | module function avgpool2d(input, pool_size, stride) result(output) | ||
| 128 | type(array_type), intent(in), target :: input | ||
| 129 | integer, dimension(2), intent(in) :: pool_size | ||
| 130 | integer, dimension(2), intent(in) :: stride | ||
| 131 | type(array_type), pointer :: output | ||
| 132 | end function avgpool2d | ||
| 133 | |||
| 134 | module function avgpool3d(input, pool_size, stride) result(output) | ||
| 135 | type(array_type), intent(in), target :: input | ||
| 136 | integer, dimension(3), intent(in) :: pool_size | ||
| 137 | integer, dimension(3), intent(in) :: stride | ||
| 138 | type(array_type), pointer :: output | ||
| 139 | end function avgpool3d | ||
| 140 | end interface | ||
| 141 | |||
| 142 | interface | ||
| 143 | module function maxpool1d(input, pool_size, stride) result(output) | ||
| 144 | type(array_type), intent(in), target :: input | ||
| 145 | integer, intent(in) :: pool_size | ||
| 146 | integer, intent(in) :: stride | ||
| 147 | type(array_type), pointer :: output | ||
| 148 | end function maxpool1d | ||
| 149 | |||
| 150 | module function maxpool2d(input, pool_size, stride) result(output) | ||
| 151 | type(array_type), intent(in), target :: input | ||
| 152 | integer, dimension(2), intent(in) :: pool_size | ||
| 153 | integer, dimension(2), intent(in) :: stride | ||
| 154 | type(array_type), pointer :: output | ||
| 155 | end function maxpool2d | ||
| 156 | |||
| 157 | module function maxpool3d(input, pool_size, stride) result(output) | ||
| 158 | type(array_type), intent(in), target :: input | ||
| 159 | integer, dimension(3), intent(in) :: pool_size | ||
| 160 | integer, dimension(3), intent(in) :: stride | ||
| 161 | type(array_type), pointer :: output | ||
| 162 | end function maxpool3d | ||
| 163 | end interface | ||
| 164 | |||
| 165 | interface | ||
| 166 | module function pad1d(input, facets, pad_size, imethod) result(output) | ||
| 167 | type(array_type), intent(in), target :: input | ||
| 168 | type(facets_type), intent(in) :: facets | ||
| 169 | integer, intent(in) :: pad_size | ||
| 170 | integer, intent(in) :: imethod | ||
| 171 | type(array_type), pointer :: output | ||
| 172 | end function pad1d | ||
| 173 | |||
| 174 | module function pad2d(input, facets, pad_size, imethod) result(output) | ||
| 175 | type(array_type), intent(in), target :: input | ||
| 176 | type(facets_type), dimension(2), intent(in) :: facets | ||
| 177 | integer, dimension(2), intent(in) :: pad_size | ||
| 178 | integer, intent(in) :: imethod | ||
| 179 | type(array_type), pointer :: output | ||
| 180 | end function pad2d | ||
| 181 | |||
| 182 | module function pad3d(input, facets, pad_size, imethod) result(output) | ||
| 183 | type(array_type), intent(in), target :: input | ||
| 184 | type(facets_type), dimension(3), intent(in) :: facets | ||
| 185 | integer, dimension(3), intent(in) :: pad_size | ||
| 186 | integer, intent(in) :: imethod | ||
| 187 | type(array_type), pointer :: output | ||
| 188 | end function pad3d | ||
| 189 | end interface | ||
| 190 | |||
| 191 | interface merge_over_channels | ||
| 192 | module function merge_scalar_over_channels(tsource, fsource, mask) result(output) | ||
| 193 | class(array_type), intent(in), target :: tsource | ||
| 194 | real(real32), intent(in) :: fsource | ||
| 195 | logical, dimension(:,:), intent(in) :: mask | ||
| 196 | type(array_type), pointer :: output | ||
| 197 | end function merge_scalar_over_channels | ||
| 198 | end interface | ||
| 199 | |||
| 200 | interface | ||
| 201 | module function batchnorm( & | ||
| 202 | input, params, momentum, mean, variance, epsilon & | ||
| 203 | ) result( output ) | ||
| 204 | class(array_type), intent(in), target :: input | ||
| 205 | class(array_type), intent(in), target :: params | ||
| 206 | real(real32), intent(in) :: momentum | ||
| 207 | real(real32), dimension(:), intent(in) :: mean | ||
| 208 | real(real32), dimension(:), intent(in) :: variance | ||
| 209 | real(real32), intent(in) :: epsilon | ||
| 210 | type(batchnorm_array_type), pointer :: output | ||
| 211 | end function batchnorm | ||
| 212 | |||
| 213 | module function batchnorm_inference( & | ||
| 214 | input, params, mean, variance, epsilon & | ||
| 215 | ) result( output ) | ||
| 216 | class(array_type), intent(in), target :: input | ||
| 217 | class(array_type), intent(in), target :: params | ||
| 218 | real(real32), dimension(:), intent(in) :: mean | ||
| 219 | real(real32), dimension(:), intent(in) :: variance | ||
| 220 | real(real32), intent(in) :: epsilon | ||
| 221 | type(batchnorm_array_type), pointer :: output | ||
| 222 | end function batchnorm_inference | ||
| 223 | end interface | ||
| 224 | |||
| 225 | interface | ||
| 226 | module function conv1d(input, kernel, stride, dilation) result(output) | ||
| 227 | type(array_type), intent(in), target :: input | ||
| 228 | type(array_type), intent(in), target :: kernel | ||
| 229 | integer, intent(in) :: stride | ||
| 230 | integer, intent(in) :: dilation | ||
| 231 | type(array_type), pointer :: output | ||
| 232 | end function conv1d | ||
| 233 | |||
| 234 | module function conv2d(input, kernel, stride, dilation) result(output) | ||
| 235 | type(array_type), intent(in), target :: input | ||
| 236 | type(array_type), intent(in), target :: kernel | ||
| 237 | integer, dimension(2), intent(in) :: stride | ||
| 238 | integer, dimension(2), intent(in) :: dilation | ||
| 239 | type(array_type), pointer :: output | ||
| 240 | end function conv2d | ||
| 241 | |||
| 242 | module function conv3d(input, kernel, stride, dilation) result(output) | ||
| 243 | type(array_type), intent(in), target :: input | ||
| 244 | type(array_type), intent(in), target :: kernel | ||
| 245 | integer, dimension(3), intent(in) :: stride | ||
| 246 | integer, dimension(3), intent(in) :: dilation | ||
| 247 | type(array_type), pointer :: output | ||
| 248 | end function conv3d | ||
| 249 | end interface | ||
| 250 | |||
| 251 | interface | ||
| 252 | module function kipf_propagate(vertex_features, adj_ia, adj_ja) result(c) | ||
| 253 | !! Propagate values from one autodiff array to another | ||
| 254 | class(array_type), intent(in), target :: vertex_features | ||
| 255 | integer, dimension(:), intent(in) :: adj_ia | ||
| 256 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 257 | type(array_type), pointer :: c | ||
| 258 | end function kipf_propagate | ||
| 259 | |||
| 260 | module function kipf_update(a, weight, adj_ia) result(c) | ||
| 261 | !! Update the message passing layer | ||
| 262 | class(array_type), intent(in), target :: a | ||
| 263 | class(array_type), intent(in), target :: weight | ||
| 264 | integer, dimension(:), intent(in) :: adj_ia | ||
| 265 | type(array_type), pointer :: c | ||
| 266 | end function kipf_update | ||
| 267 | end interface | ||
| 268 | |||
| 269 | interface | ||
| 270 | module function duvenaud_propagate( & | ||
| 271 | vertex_features, edge_features, adj_ia, adj_ja & | ||
| 272 | ) result(c) | ||
| 273 | !! Duvenaud message passing function | ||
| 274 | class(array_type), intent(in), target :: vertex_features | ||
| 275 | class(array_type), intent(in), target :: edge_features | ||
| 276 | integer, dimension(:), intent(in) :: adj_ia | ||
| 277 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 278 | type(array_type), pointer :: c | ||
| 279 | end function duvenaud_propagate | ||
| 280 | |||
| 281 | module function duvenaud_update( & | ||
| 282 | a, weight, adj_ia, min_degree, max_degree & | ||
| 283 | ) result(c) | ||
| 284 | !! Duvenaud update function | ||
| 285 | class(array_type), intent(in), target :: a | ||
| 286 | class(array_type), intent(in), target :: weight | ||
| 287 | integer, dimension(:), intent(in) :: adj_ia | ||
| 288 | integer, intent(in) :: min_degree, max_degree | ||
| 289 | type(array_type), pointer :: c | ||
| 290 | end function duvenaud_update | ||
| 291 | end interface | ||
| 292 | |||
| 293 | interface | ||
| 294 | module function gno_kernel_eval( & | ||
| 295 | coords, kernel_params, adj_ia, adj_ja, & | ||
| 296 | coord_dim, kernel_hidden, F_in, F_out & | ||
| 297 | ) result(c) | ||
| 298 | !! Evaluate GNO kernel MLP on every edge | ||
| 299 | class(array_type), intent(in), target :: coords | ||
| 300 | class(array_type), intent(in), target :: kernel_params | ||
| 301 | integer, dimension(:), intent(in) :: adj_ia | ||
| 302 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 303 | integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out | ||
| 304 | type(array_type), pointer :: c | ||
| 305 | end function gno_kernel_eval | ||
| 306 | |||
| 307 | module function gno_aggregate( & | ||
| 308 | features, edge_kernels, adj_ia, adj_ja, F_in, F_out & | ||
| 309 | ) result(c) | ||
| 310 | !! Aggregate neighbour messages using per-edge kernels | ||
| 311 | class(array_type), intent(in), target :: features | ||
| 312 | class(array_type), intent(in), target :: edge_kernels | ||
| 313 | integer, dimension(:), intent(in) :: adj_ia | ||
| 314 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 315 | integer, intent(in) :: F_in, F_out | ||
| 316 | type(array_type), pointer :: c | ||
| 317 | end function gno_aggregate | ||
| 318 | end interface | ||
| 319 | |||
| 320 | interface | ||
| 321 | module function lno_encode( & | ||
| 322 | input, poles, num_inputs, num_modes & | ||
| 323 | ) result(c) | ||
| 324 | !! Encode input via Laplace basis: E(mu) @ u | ||
| 325 | class(array_type), intent(in), target :: input | ||
| 326 | class(array_type), intent(in), target :: poles | ||
| 327 | integer, intent(in) :: num_inputs, num_modes | ||
| 328 | type(array_type), pointer :: c | ||
| 329 | end function lno_encode | ||
| 330 | |||
| 331 | module function lno_decode( & | ||
| 332 | spectral, poles, num_outputs, num_modes & | ||
| 333 | ) result(c) | ||
| 334 | !! Decode via Laplace basis: D(mu) @ spectral | ||
| 335 | class(array_type), intent(in), target :: spectral | ||
| 336 | class(array_type), intent(in), target :: poles | ||
| 337 | integer, intent(in) :: num_outputs, num_modes | ||
| 338 | type(array_type), pointer :: c | ||
| 339 | end function lno_decode | ||
| 340 | end interface | ||
| 341 | |||
| 342 | interface | ||
| 343 | module function elem_scale(input, scale) result(c) | ||
| 344 | !! Element-wise multiply: out[i,s] = input[i,s] * scale[i,1] | ||
| 345 | !! Correctly handles non-sample-dependent scale vectors. | ||
| 346 | class(array_type), intent(in), target :: input | ||
| 347 | class(array_type), intent(in), target :: scale | ||
| 348 | type(array_type), pointer :: c | ||
| 349 | end function elem_scale | ||
| 350 | end interface | ||
| 351 | |||
| 352 | interface | ||
| 353 | module function ono_encode( & | ||
| 354 | input, basis_weights, num_inputs, num_basis & | ||
| 355 | ) result(c) | ||
| 356 | !! Encode via orthogonal basis: Q(B)^T @ u | ||
| 357 | class(array_type), intent(in), target :: input | ||
| 358 | class(array_type), intent(in), target :: basis_weights | ||
| 359 | integer, intent(in) :: num_inputs, num_basis | ||
| 360 | type(array_type), pointer :: c | ||
| 361 | end function ono_encode | ||
| 362 | |||
| 363 | module function ono_decode( & | ||
| 364 | mixed, basis_weights, num_inputs, num_basis & | ||
| 365 | ) result(c) | ||
| 366 | !! Decode via orthogonal basis: Q(B) @ mixed | ||
| 367 | class(array_type), intent(in), target :: mixed | ||
| 368 | class(array_type), intent(in), target :: basis_weights | ||
| 369 | integer, intent(in) :: num_inputs, num_basis | ||
| 370 | type(array_type), pointer :: c | ||
| 371 | end function ono_decode | ||
| 372 | end interface | ||
| 373 | !------------------------------------------------------------------------------- | ||
| 374 | |||
| 375 |
39/61✓ Branch 0 taken 36 times.
✓ Branch 1 taken 33 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 84 times.
✓ Branch 4 taken 33 times.
✓ Branch 5 taken 36 times.
✓ Branch 6 taken 48 times.
✓ Branch 7 taken 36 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 84 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 36 times.
✓ Branch 12 taken 48 times.
✓ Branch 13 taken 36 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 84 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 36 times.
✓ Branch 18 taken 48 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 48 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 48 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 48 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 48 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 48 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 48 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 48 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 33 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 33 times.
✓ Branch 45 taken 33 times.
✓ Branch 46 taken 33 times.
✓ Branch 47 taken 48 times.
✓ Branch 48 taken 33 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 48 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 48 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 48 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 48 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 48 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 48 times.
✓ Branch 67 taken 33 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 21 times.
✓ Branch 70 taken 12 times.
✓ Branch 71 taken 33 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 21 times.
✓ Branch 74 taken 12 times.
|
396 | end module athena__diffstruc_extd |
| 376 |