| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | module athena__diffstruc_extd | ||
| 2 | !! Module for extended differential structure types for Athena | ||
| 3 | use coreutils, only: real32 | ||
| 4 | use diffstruc, only: array_type | ||
| 5 | use athena__misc_types, only: facets_type | ||
| 6 | implicit none | ||
| 7 | |||
| 8 | |||
| 9 | private | ||
| 10 | |||
| 11 | public :: array_ptr_type | ||
| 12 | public :: add_layers, concat_layers | ||
| 13 | public :: add_bias | ||
| 14 | public :: piecewise, softmax, swish | ||
| 15 | public :: huber | ||
| 16 | public :: avgpool1d, avgpool2d, avgpool3d | ||
| 17 | public :: maxpool1d, maxpool2d, maxpool3d | ||
| 18 | public :: pad1d, pad2d, pad3d | ||
| 19 | public :: merge_over_channels | ||
| 20 | public :: batchnorm_array_type, batchnorm, batchnorm_inference | ||
| 21 | public :: conv1d, conv2d, conv3d | ||
| 22 | public :: kipf_propagate, kipf_update | ||
| 23 | public :: duvenaud_propagate, duvenaud_update | ||
| 24 | |||
| 25 | |||
| 26 | type, extends(array_type) :: batchnorm_array_type | ||
| 27 | real(real32), dimension(:), allocatable :: mean | ||
| 28 | real(real32), dimension(:), allocatable :: variance | ||
| 29 | real(real32) :: epsilon | ||
| 30 | end type batchnorm_array_type | ||
| 31 | |||
| 32 | |||
| 33 | !------------------------------------------------------------------------------- | ||
| 34 | ! Array container types | ||
| 35 | !------------------------------------------------------------------------------- | ||
| 36 | type :: array_ptr_type | ||
| 37 | type(array_type), pointer :: array(:,:) | ||
| 38 | end type array_ptr_type | ||
| 39 | |||
| 40 | ! Operator interfaces | ||
| 41 | !----------------------------------------------------------------------------- | ||
| 42 | interface add_layers | ||
| 43 | module function add_array_ptr(a, idx1, idx2) result(c) | ||
| 44 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 45 | integer, intent(in) :: idx1, idx2 | ||
| 46 | type(array_type), pointer :: c | ||
| 47 | end function add_array_ptr | ||
| 48 | end interface | ||
| 49 | |||
| 50 | interface concat_layers | ||
| 51 | module function concat_array_ptr(a, idx1, idx2, dim) result(c) | ||
| 52 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 53 | integer, intent(in) :: idx1, idx2, dim | ||
| 54 | type(array_type), pointer :: c | ||
| 55 | end function concat_array_ptr | ||
| 56 | end interface | ||
| 57 | !------------------------------------------------------------------------------- | ||
| 58 | |||
| 59 | |||
| 60 | !------------------------------------------------------------------------------- | ||
| 61 | ! Activation functions and other operations | ||
| 62 | !------------------------------------------------------------------------------- | ||
| 63 | interface | ||
| 64 | module function add_bias(input, bias, dim, dim_act_on_shape) result(output) | ||
| 65 | class(array_type), intent(in), target :: input | ||
| 66 | class(array_type), intent(in), target :: bias | ||
| 67 | integer, intent(in) :: dim | ||
| 68 | logical, intent(in), optional :: dim_act_on_shape | ||
| 69 | type(array_type), pointer :: output | ||
| 70 | end function add_bias | ||
| 71 | end interface | ||
| 72 | |||
| 73 | interface piecewise | ||
| 74 | module function piecewise_array(input, gradient, limit) result( output ) | ||
| 75 | class(array_type), intent(in), target :: input | ||
| 76 | real(real32), intent(in) :: gradient | ||
| 77 | real(real32), intent(in) :: limit | ||
| 78 | type(array_type), pointer :: output | ||
| 79 | end function piecewise_array | ||
| 80 | end interface | ||
| 81 | |||
| 82 | interface softmax | ||
| 83 | module function softmax_array(input, dim) result(output) | ||
| 84 | class(array_type), intent(in), target :: input | ||
| 85 | integer, intent(in) :: dim | ||
| 86 | type(array_type), pointer :: output | ||
| 87 | end function softmax_array | ||
| 88 | end interface | ||
| 89 | |||
| 90 | interface swish | ||
| 91 | module function swish_array(input, beta) result(output) | ||
| 92 | class(array_type), intent(in), target :: input | ||
| 93 | real(real32), intent(in) :: beta | ||
| 94 | type(array_type), pointer :: output | ||
| 95 | end function swish_array | ||
| 96 | end interface | ||
| 97 | !------------------------------------------------------------------------------- | ||
| 98 | |||
| 99 | |||
| 100 | !------------------------------------------------------------------------------- | ||
| 101 | ! Loss functions | ||
| 102 | !------------------------------------------------------------------------------- | ||
| 103 | interface huber | ||
| 104 | module function huber_array(delta, gamma) result( output ) | ||
| 105 | class(array_type), intent(in), target :: delta | ||
| 106 | real(real32), intent(in) :: gamma | ||
| 107 | type(array_type), pointer :: output | ||
| 108 | end function huber_array | ||
| 109 | end interface | ||
| 110 | !------------------------------------------------------------------------------- | ||
| 111 | |||
| 112 | |||
| 113 | !------------------------------------------------------------------------------- | ||
| 114 | ! Layer operations | ||
| 115 | !------------------------------------------------------------------------------- | ||
| 116 | interface | ||
| 117 | module function avgpool1d(input, pool_size, stride) result(output) | ||
| 118 | type(array_type), intent(in), target :: input | ||
| 119 | integer, intent(in) :: pool_size | ||
| 120 | integer, intent(in) :: stride | ||
| 121 | type(array_type), pointer :: output | ||
| 122 | end function avgpool1d | ||
| 123 | |||
| 124 | module function avgpool2d(input, pool_size, stride) result(output) | ||
| 125 | type(array_type), intent(in), target :: input | ||
| 126 | integer, dimension(2), intent(in) :: pool_size | ||
| 127 | integer, dimension(2), intent(in) :: stride | ||
| 128 | type(array_type), pointer :: output | ||
| 129 | end function avgpool2d | ||
| 130 | |||
| 131 | module function avgpool3d(input, pool_size, stride) result(output) | ||
| 132 | type(array_type), intent(in), target :: input | ||
| 133 | integer, dimension(3), intent(in) :: pool_size | ||
| 134 | integer, dimension(3), intent(in) :: stride | ||
| 135 | type(array_type), pointer :: output | ||
| 136 | end function avgpool3d | ||
| 137 | end interface | ||
| 138 | |||
| 139 | interface | ||
| 140 | module function maxpool1d(input, pool_size, stride) result(output) | ||
| 141 | type(array_type), intent(in), target :: input | ||
| 142 | integer, intent(in) :: pool_size | ||
| 143 | integer, intent(in) :: stride | ||
| 144 | type(array_type), pointer :: output | ||
| 145 | end function maxpool1d | ||
| 146 | |||
| 147 | module function maxpool2d(input, pool_size, stride) result(output) | ||
| 148 | type(array_type), intent(in), target :: input | ||
| 149 | integer, dimension(2), intent(in) :: pool_size | ||
| 150 | integer, dimension(2), intent(in) :: stride | ||
| 151 | type(array_type), pointer :: output | ||
| 152 | end function maxpool2d | ||
| 153 | |||
| 154 | module function maxpool3d(input, pool_size, stride) result(output) | ||
| 155 | type(array_type), intent(in), target :: input | ||
| 156 | integer, dimension(3), intent(in) :: pool_size | ||
| 157 | integer, dimension(3), intent(in) :: stride | ||
| 158 | type(array_type), pointer :: output | ||
| 159 | end function maxpool3d | ||
| 160 | end interface | ||
| 161 | |||
| 162 | interface | ||
| 163 | module function pad1d(input, facets, pad_size, imethod) result(output) | ||
| 164 | type(array_type), intent(in), target :: input | ||
| 165 | type(facets_type), intent(in) :: facets | ||
| 166 | integer, intent(in) :: pad_size | ||
| 167 | integer, intent(in) :: imethod | ||
| 168 | type(array_type), pointer :: output | ||
| 169 | end function pad1d | ||
| 170 | |||
| 171 | module function pad2d(input, facets, pad_size, imethod) result(output) | ||
| 172 | type(array_type), intent(in), target :: input | ||
| 173 | type(facets_type), dimension(2), intent(in) :: facets | ||
| 174 | integer, dimension(2), intent(in) :: pad_size | ||
| 175 | integer, intent(in) :: imethod | ||
| 176 | type(array_type), pointer :: output | ||
| 177 | end function pad2d | ||
| 178 | |||
| 179 | module function pad3d(input, facets, pad_size, imethod) result(output) | ||
| 180 | type(array_type), intent(in), target :: input | ||
| 181 | type(facets_type), dimension(3), intent(in) :: facets | ||
| 182 | integer, dimension(3), intent(in) :: pad_size | ||
| 183 | integer, intent(in) :: imethod | ||
| 184 | type(array_type), pointer :: output | ||
| 185 | end function pad3d | ||
| 186 | end interface | ||
| 187 | |||
| 188 | interface merge_over_channels | ||
| 189 | module function merge_scalar_over_channels(tsource, fsource, mask) result(output) | ||
| 190 | class(array_type), intent(in), target :: tsource | ||
| 191 | real(real32), intent(in) :: fsource | ||
| 192 | logical, dimension(:,:), intent(in) :: mask | ||
| 193 | type(array_type), pointer :: output | ||
| 194 | end function merge_scalar_over_channels | ||
| 195 | end interface | ||
| 196 | |||
| 197 | interface | ||
| 198 | module function batchnorm( & | ||
| 199 | input, params, momentum, mean, variance, epsilon & | ||
| 200 | ) result( output ) | ||
| 201 | class(array_type), intent(in), target :: input | ||
| 202 | class(array_type), intent(in), target :: params | ||
| 203 | real(real32), intent(in) :: momentum | ||
| 204 | real(real32), dimension(:), intent(in) :: mean | ||
| 205 | real(real32), dimension(:), intent(in) :: variance | ||
| 206 | real(real32), intent(in) :: epsilon | ||
| 207 | type(batchnorm_array_type), pointer :: output | ||
| 208 | end function batchnorm | ||
| 209 | |||
| 210 | module function batchnorm_inference( & | ||
| 211 | input, params, mean, variance, epsilon & | ||
| 212 | ) result( output ) | ||
| 213 | class(array_type), intent(in), target :: input | ||
| 214 | class(array_type), intent(in), target :: params | ||
| 215 | real(real32), dimension(:), intent(in) :: mean | ||
| 216 | real(real32), dimension(:), intent(in) :: variance | ||
| 217 | real(real32), intent(in) :: epsilon | ||
| 218 | type(batchnorm_array_type), pointer :: output | ||
| 219 | end function batchnorm_inference | ||
| 220 | end interface | ||
| 221 | |||
| 222 | interface | ||
| 223 | module function conv1d(input, kernel, stride, dilation) result(output) | ||
| 224 | type(array_type), intent(in), target :: input | ||
| 225 | type(array_type), intent(in), target :: kernel | ||
| 226 | integer, intent(in) :: stride | ||
| 227 | integer, intent(in) :: dilation | ||
| 228 | type(array_type), pointer :: output | ||
| 229 | end function conv1d | ||
| 230 | |||
| 231 | module function conv2d(input, kernel, stride, dilation) result(output) | ||
| 232 | type(array_type), intent(in), target :: input | ||
| 233 | type(array_type), intent(in), target :: kernel | ||
| 234 | integer, dimension(2), intent(in) :: stride | ||
| 235 | integer, dimension(2), intent(in) :: dilation | ||
| 236 | type(array_type), pointer :: output | ||
| 237 | end function conv2d | ||
| 238 | |||
| 239 | module function conv3d(input, kernel, stride, dilation) result(output) | ||
| 240 | type(array_type), intent(in), target :: input | ||
| 241 | type(array_type), intent(in), target :: kernel | ||
| 242 | integer, dimension(3), intent(in) :: stride | ||
| 243 | integer, dimension(3), intent(in) :: dilation | ||
| 244 | type(array_type), pointer :: output | ||
| 245 | end function conv3d | ||
| 246 | end interface | ||
| 247 | |||
| 248 | interface | ||
| 249 | module function kipf_propagate(vertex_features, adj_ia, adj_ja) result(c) | ||
| 250 | !! Propagate values from one autodiff array to another | ||
| 251 | class(array_type), intent(in), target :: vertex_features | ||
| 252 | integer, dimension(:), intent(in) :: adj_ia | ||
| 253 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 254 | type(array_type), pointer :: c | ||
| 255 | end function kipf_propagate | ||
| 256 | |||
| 257 | module function kipf_update(a, weight, adj_ia) result(c) | ||
| 258 | !! Update the message passing layer | ||
| 259 | class(array_type), intent(in), target :: a | ||
| 260 | class(array_type), intent(in), target :: weight | ||
| 261 | integer, dimension(:), intent(in) :: adj_ia | ||
| 262 | type(array_type), pointer :: c | ||
| 263 | end function kipf_update | ||
| 264 | end interface | ||
| 265 | |||
| 266 | interface | ||
| 267 | module function duvenaud_propagate( & | ||
| 268 | vertex_features, edge_features, adj_ia, adj_ja & | ||
| 269 | ) result(c) | ||
| 270 | !! Duvenaud message passing function | ||
| 271 | class(array_type), intent(in), target :: vertex_features | ||
| 272 | class(array_type), intent(in), target :: edge_features | ||
| 273 | integer, dimension(:), intent(in) :: adj_ia | ||
| 274 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 275 | type(array_type), pointer :: c | ||
| 276 | end function duvenaud_propagate | ||
| 277 | |||
| 278 | module function duvenaud_update( & | ||
| 279 | a, weight, adj_ia, min_degree, max_degree & | ||
| 280 | ) result(c) | ||
| 281 | !! Duvenaud update function | ||
| 282 | class(array_type), intent(in), target :: a | ||
| 283 | class(array_type), intent(in), target :: weight | ||
| 284 | integer, dimension(:), intent(in) :: adj_ia | ||
| 285 | integer, intent(in) :: min_degree, max_degree | ||
| 286 | type(array_type), pointer :: c | ||
| 287 | end function duvenaud_update | ||
| 288 | end interface | ||
| 289 | !------------------------------------------------------------------------------- | ||
| 290 | |||
| 291 | − | end module athena__diffstruc_extd | |
| 292 |