| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | !!!############################################################################# | ||
| 2 | !!! Code written by Ned Thaddeus Taylor | ||
| 3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
| 4 | !!!############################################################################# | ||
| 5 | !!! module contains the network class, which is used to define a neural network | ||
| 6 | !!! module contains the following derived types: | ||
| 7 | !!! - network_type | ||
| 8 | !!!################## | ||
| 9 | !!! network_type contains the following procedures: | ||
| 10 | !!! - print - print the network to file | ||
| 11 | !!! - read - read the network from a file | ||
| 12 | !!! - add - add a layer to the network | ||
| 13 | !!! - reset - reset the network | ||
| 14 | !!! - compile - compile the network | ||
| 15 | !!! - set_batch_size - set batch size | ||
| 16 | !!! - set_metrics - set network metrics | ||
| 17 | !!! - set_loss - set network loss method | ||
| 18 | !!! - train - train the network | ||
| 19 | !!! - test - test the network | ||
| 20 | !!! - predict - return predicted results from supplied inputs using ... | ||
| 21 | !!! ... the trained network | ||
| 22 | !!! - update - update the learnable parameters of the network based ... | ||
| 23 | !!! ... on gradients | ||
| 24 | !!! - reduce - reduce two networks down to one ... | ||
| 25 | !!! ... (i.e. add two networks - parallel) | ||
| 26 | !!! - copy - copy a network | ||
| 27 | !!! - get_num_params - get number of learnable parameters in the network | ||
| 28 | !!! - get_params - get learnable parameters | ||
| 29 | !!! - set_params - set learnable parameters | ||
| 30 | !!! - get_gradients - get gradients of learnable parameters | ||
| 31 | !!! - set_gradients - set learnable parameter gradients | ||
| 32 | !!! - reset_gradients - reset learnable parameter gradients | ||
| 33 | !!! - forward - forward pass | ||
| 34 | !!! - backward - backward pass | ||
| 35 | !!!############################################################################# | ||
| 36 | module network !! defines network_type and the interfaces of its procedures | ||
| 37 | use constants, only: real12 | ||
| 38 | use metrics, only: metric_dict_type | ||
| 39 | use optimiser, only: base_optimiser_type | ||
| 40 | use loss, only: & | ||
| 41 | comp_loss_func => compute_loss_function, & | ||
| 42 | comp_loss_deriv => compute_loss_derivative | ||
| 43 | use base_layer, only: base_layer_type | ||
| 44 | use container_layer, only: container_layer_type | ||
| 45 | implicit none | ||
| 46 | |||
| 47 | private | ||
| 48 | |||
| 49 | public :: network_type | ||
| 50 | |||
| 51 | |||
| 52 | type :: network_type | ||
| 53 | real(real12) :: accuracy, loss !! NOTE(review): no default initialisation | ||
| 54 | integer :: batch_size = 0 | ||
| 55 | integer :: num_layers = 0 | ||
| 56 | integer :: num_outputs = 0 | ||
| 57 | class(base_optimiser_type), allocatable :: optimiser | ||
| 58 | type(metric_dict_type), dimension(2) :: metrics | ||
| 59 | type(container_layer_type), allocatable, dimension(:) :: model | ||
| 60 | procedure(comp_loss_func), nopass, pointer :: get_loss => null() !! presumably set by set_loss; confirm | ||
| 61 | procedure(comp_loss_deriv), nopass, pointer :: get_loss_deriv => null() !! presumably set by set_loss; confirm | ||
| 62 | contains | ||
| 63 | procedure, pass(this) :: print | ||
| 64 | procedure, pass(this) :: read | ||
| 65 | procedure, pass(this) :: add | ||
| 66 | procedure, pass(this) :: reset | ||
| 67 | procedure, pass(this) :: compile | ||
| 68 | procedure, pass(this) :: set_batch_size | ||
| 69 | procedure, pass(this) :: set_metrics | ||
| 70 | procedure, pass(this) :: set_loss | ||
| 71 | procedure, pass(this) :: train | ||
| 72 | procedure, pass(this) :: test | ||
| 73 | procedure, pass(this) :: predict => predict_1d | ||
| 74 | procedure, pass(this) :: update | ||
| 75 | |||
| 76 | procedure, pass(this) :: reduce => network_reduction | ||
| 77 | procedure, pass(this) :: copy => network_copy | ||
| 78 | |||
| 79 | procedure, pass(this) :: get_num_params | ||
| 80 | procedure, pass(this) :: get_params | ||
| 81 | procedure, pass(this) :: set_params | ||
| 82 | procedure, pass(this) :: get_gradients | ||
| 83 | procedure, pass(this) :: set_gradients | ||
| 84 | procedure, pass(this) :: reset_gradients | ||
| 85 | |||
| 86 | procedure, pass(this) :: forward => forward_1d | ||
| 87 | procedure, pass(this) :: backward => backward_1d | ||
| 88 | end type network_type | ||
| 89 | |||
| 90 | interface network_type | ||
| 91 | !!------------------------------------------------------------------------- | ||
| 92 | !! setup the network (network initialisation) | ||
| 93 | !!------------------------------------------------------------------------- | ||
| 94 | !! layers = (T, in) layer container | ||
| 95 | !! optimiser = (T, in, opt) optimiser | ||
| 96 | !! loss_method = (S, in, opt) loss method | ||
| 97 | !! metrics = (*, in, opt) metrics, either string or metric_dict_type | ||
| 98 | !! batch_size = (I, in, opt) batch size | ||
| 99 | module function network_setup( & | ||
| 100 | layers, & | ||
| 101 | optimiser, loss_method, metrics, batch_size) result(network) | ||
| 102 | type(container_layer_type), dimension(:), intent(in) :: layers | ||
| 103 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 104 | character(*), optional, intent(in) :: loss_method | ||
| 105 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 106 | integer, optional, intent(in) :: batch_size | ||
| 107 | type(network_type) :: network | ||
| 108 | end function network_setup | ||
| 109 | end interface network_type | ||
| 110 | |||
| 111 | interface | ||
| 112 | !!------------------------------------------------------------------------- | ||
| 113 | !! print the network to file | ||
| 114 | !!------------------------------------------------------------------------- | ||
| 115 | !! this = (T, in) network type | ||
| 116 | !! file = (S, in) file name | ||
| 117 | module subroutine print(this, file) | ||
| 118 | class(network_type), intent(in) :: this | ||
| 119 | character(*), intent(in) :: file | ||
| 120 | end subroutine print | ||
| 121 | |||
| 122 | !!------------------------------------------------------------------------- | ||
| 123 | !! read the network from a file | ||
| 124 | !!------------------------------------------------------------------------- | ||
| 125 | !! this = (T, io) network type | ||
| 126 | !! file = (S, in) file name | ||
| 127 | module subroutine read(this, file) | ||
| 128 | class(network_type), intent(inout) :: this | ||
| 129 | character(*), intent(in) :: file | ||
| 130 | end subroutine read | ||
| 131 | |||
| 132 | !!------------------------------------------------------------------------- | ||
| 133 | !! add a layer to the network | ||
| 134 | !!------------------------------------------------------------------------- | ||
| 135 | !! this = (T, io) network type | ||
| 136 | !! layer = (T, in) layer to add | ||
| 137 | module subroutine add(this, layer) | ||
| 138 | class(network_type), intent(inout) :: this | ||
| 139 | class(base_layer_type), intent(in) :: layer | ||
| 140 | end subroutine add | ||
| 141 | |||
| 142 | !!------------------------------------------------------------------------- | ||
| 143 | !! reset the network | ||
| 144 | !!------------------------------------------------------------------------- | ||
| 145 | !! this = (T, io) network type | ||
| 146 | module subroutine reset(this) | ||
| 147 | class(network_type), intent(inout) :: this | ||
| 148 | end subroutine reset | ||
| 149 | |||
| 150 | !!------------------------------------------------------------------------- | ||
| 151 | !! compile the network | ||
| 152 | !!------------------------------------------------------------------------- | ||
| 153 | !! this = (T, io) network type | ||
| 154 | !! optimiser = (T, in) optimiser | ||
| 155 | !! loss_method = (S, in, opt) loss method | ||
| 156 | !! metrics = (*, in, opt) metrics, either string or metric_dict_type | ||
| 157 | !! batch_size = (I, in, opt) batch size | ||
| 158 | !! verbose = (I, in, opt) verbosity level | ||
| 159 | module subroutine compile(this, optimiser, loss_method, metrics, & | ||
| 160 | batch_size, verbose) | ||
| 161 | class(network_type), intent(inout) :: this | ||
| 162 | class(base_optimiser_type), intent(in) :: optimiser | ||
| 163 | character(*), optional, intent(in) :: loss_method | ||
| 164 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 165 | integer, optional, intent(in) :: batch_size | ||
| 166 | integer, optional, intent(in) :: verbose | ||
| 167 | end subroutine compile | ||
| 168 | |||
| 169 | !!------------------------------------------------------------------------- | ||
| 170 | !! set batch size | ||
| 171 | !!------------------------------------------------------------------------- | ||
| 172 | !! this = (T, io) network type | ||
| 173 | !! batch_size = (I, in) batch size to use | ||
| 174 | module subroutine set_batch_size(this, batch_size) | ||
| 175 | class(network_type), intent(inout) :: this | ||
| 176 | integer, intent(in) :: batch_size | ||
| 177 | end subroutine set_batch_size | ||
| 178 | |||
| 179 | !!------------------------------------------------------------------------- | ||
| 180 | !! set network metrics | ||
| 181 | !!------------------------------------------------------------------------- | ||
| 182 | !! this = (T, io) network type | ||
| 183 | !! metrics = (*, in) metrics to use | ||
| 184 | module subroutine set_metrics(this, metrics) | ||
| 185 | class(network_type), intent(inout) :: this | ||
| 186 | class(*), dimension(..), intent(in) :: metrics | ||
| 187 | end subroutine set_metrics | ||
| 188 | |||
| 189 | !!------------------------------------------------------------------------- | ||
| 190 | !! set network loss method | ||
| 191 | !!------------------------------------------------------------------------- | ||
| 192 | !! this = (T, io) network type | ||
| 193 | !! loss_method = (S, in) loss method to use | ||
| 194 | !! verbose = (I, in, opt) verbosity level | ||
| 195 | module subroutine set_loss(this, loss_method, verbose) | ||
| 196 | class(network_type), intent(inout) :: this | ||
| 197 | character(*), intent(in) :: loss_method | ||
| 198 | integer, optional, intent(in) :: verbose | ||
| 199 | end subroutine set_loss | ||
| 200 | |||
| 201 | !!------------------------------------------------------------------------- | ||
| 202 | !! train the network | ||
| 203 | !!------------------------------------------------------------------------- | ||
| 204 | !! this = (T, io) network type | ||
| 205 | !! input = (R, in) input data | ||
| 206 | !! output = (*, in) expected output data (data labels) | ||
| 207 | !! num_epochs = (I, in) number of epochs to train for | ||
| 208 | !! batch_size = (I, in, opt) batch size (DEPRECATED) | ||
| 209 | !! addit_input = (R, in, opt) additional input data | ||
| 210 | !! addit_layer = (I, in, opt) layer to insert additional input data | ||
| 211 | !! plateau_threshold = (R, in, opt) threshold for checking learning plateau | ||
| 212 | !! shuffle_batches = (B, in, opt) shuffle batch order | ||
| 213 | !! batch_print_step = (I, in, opt) print step for batch | ||
| 214 | !! verbose = (I, in, opt) verbosity level | ||
| 215 | module subroutine train(this, input, output, num_epochs, batch_size, & | ||
| 216 | addit_input, addit_layer, & | ||
| 217 | plateau_threshold, shuffle_batches, batch_print_step, verbose) | ||
| 218 | class(network_type), intent(inout) :: this | ||
| 219 | real(real12), dimension(..), intent(in) :: input | ||
| 220 | class(*), dimension(:,:), intent(in) :: output | ||
| 221 | integer, intent(in) :: num_epochs | ||
| 222 | integer, optional, intent(in) :: batch_size !! deprecated | ||
| 223 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 224 | integer, optional, intent(in) :: addit_layer | ||
| 225 | real(real12), optional, intent(in) :: plateau_threshold | ||
| 226 | logical, optional, intent(in) :: shuffle_batches | ||
| 227 | integer, optional, intent(in) :: batch_print_step | ||
| 228 | integer, optional, intent(in) :: verbose | ||
| 229 | end subroutine train | ||
| 230 | |||
| 231 | !!------------------------------------------------------------------------- | ||
| 232 | !! test the network | ||
| 233 | !!------------------------------------------------------------------------- | ||
| 234 | !! this = (T, io) network type | ||
| 235 | !! input = (R, in) input data | ||
| 236 | !! output = (*, in) expected output data (data labels) | ||
| 237 | !! addit_input = (R, in, opt) additional input data | ||
| 238 | !! addit_layer = (I, in, opt) layer to insert additional input data | ||
| 239 | !! verbose = (I, in, opt) verbosity level | ||
| 240 | module subroutine test(this, input, output, & | ||
| 241 | addit_input, addit_layer, & | ||
| 242 | verbose) | ||
| 243 | class(network_type), intent(inout) :: this | ||
| 244 | real(real12), dimension(..), intent(in) :: input | ||
| 245 | class(*), dimension(:,:), intent(in) :: output | ||
| 246 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 247 | integer, optional, intent(in) :: addit_layer | ||
| 248 | integer, optional, intent(in) :: verbose | ||
| 249 | end subroutine test | ||
| 250 | |||
| 251 | !!------------------------------------------------------------------------- | ||
| 252 | !! return predicted results from supplied inputs using the trained network | ||
| 253 | !!------------------------------------------------------------------------- | ||
| 254 | !! this = (T, io) network type | ||
| 255 | !! input = (R, in) input data | ||
| 256 | !! addit_input = (R, in, opt) additional input data | ||
| 257 | !! addit_layer = (I, in, opt) layer to insert additional input data | ||
| 258 | !! verbose = (I, in, opt) verbosity level | ||
| 259 | !! output = (R, out) predicted output data | ||
| 260 | module function predict_1d(this, input, & | ||
| 261 | addit_input, addit_layer, & | ||
| 262 | verbose) result(output) | ||
| 263 | class(network_type), intent(inout) :: this | ||
| 264 | real(real12), dimension(..), intent(in) :: input | ||
| 265 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 266 | integer, optional, intent(in) :: addit_layer | ||
| 267 | integer, optional, intent(in) :: verbose | ||
| 268 | real(real12), dimension(:,:), allocatable :: output | ||
| 269 | end function predict_1d | ||
| 270 | |||
| 271 | !!------------------------------------------------------------------------- | ||
| 272 | !! update the learnable parameters of the network based on gradients | ||
| 273 | !!------------------------------------------------------------------------- | ||
| 274 | !! this = (T, io) network type | ||
| 275 | module subroutine update(this) | ||
| 276 | class(network_type), intent(inout) :: this | ||
| 277 | end subroutine update | ||
| 278 | |||
| 279 | !!------------------------------------------------------------------------- | ||
| 280 | !! reduce two networks down to one (i.e. add two networks - parallel) | ||
| 281 | !!------------------------------------------------------------------------- | ||
| 282 | !! this = (T, io) network type, resultant network of the reduction | ||
| 283 | !! source = (T, in) network type | ||
| 284 | module subroutine network_reduction(this, source) | ||
| 285 | class(network_type), intent(inout) :: this | ||
| 286 | type(network_type), intent(in) :: source | ||
| 287 | end subroutine network_reduction | ||
| 288 | |||
| 289 | !!------------------------------------------------------------------------- | ||
| 290 | !! copy a network | ||
| 291 | !!------------------------------------------------------------------------- | ||
| 292 | !! this = (T, io) network type, resultant network of the copy | ||
| 293 | !! source = (T, in) network type | ||
| 294 | module subroutine network_copy(this, source) | ||
| 295 | class(network_type), intent(inout) :: this | ||
| 296 | type(network_type), intent(in) :: source | ||
| 297 | end subroutine network_copy | ||
| 298 | |||
| 299 | !!------------------------------------------------------------------------- | ||
| 300 | !! get number of learnable parameters in the network | ||
| 301 | !!------------------------------------------------------------------------- | ||
| 302 | !! this = (T, in) network type | ||
| 303 | !! num_params = (I, out) number of parameters | ||
| 304 | pure module function get_num_params(this) result(num_params) | ||
| 305 | class(network_type), intent(in) :: this | ||
| 306 | integer :: num_params | ||
| 307 | end function get_num_params | ||
| 308 | |||
| 309 | !!------------------------------------------------------------------------- | ||
| 310 | !! get learnable parameters | ||
| 311 | !!------------------------------------------------------------------------- | ||
| 312 | !! this = (T, in) network type | ||
| 313 | !! params = (R, out) learnable parameters | ||
| 314 | pure module function get_params(this) result(params) | ||
| 315 | class(network_type), intent(in) :: this | ||
| 316 | real(real12), allocatable, dimension(:) :: params | ||
| 317 | end function get_params | ||
| 318 | |||
| 319 | !!------------------------------------------------------------------------- | ||
| 320 | !! set learnable parameters | ||
| 321 | !!------------------------------------------------------------------------- | ||
| 322 | !! this = (T, io) network type | ||
| 323 | !! params = (R, in) learnable parameters | ||
| 324 | !! note: no verbose argument; params presumably in get_params order (confirm) | ||
| 325 | module subroutine set_params(this, params) | ||
| 326 | class(network_type), intent(inout) :: this | ||
| 327 | real(real12), dimension(:), intent(in) :: params | ||
| 328 | end subroutine set_params | ||
| 329 | |||
| 330 | !!------------------------------------------------------------------------- | ||
| 331 | !! get gradients of learnable parameters | ||
| 332 | !!------------------------------------------------------------------------- | ||
| 333 | !! this = (T, in) network type | ||
| 334 | !! gradients = (R, out) gradients | ||
| 335 | pure module function get_gradients(this) result(gradients) | ||
| 336 | class(network_type), intent(in) :: this | ||
| 337 | real(real12), allocatable, dimension(:) :: gradients | ||
| 338 | end function get_gradients | ||
| 339 | |||
| 340 | !!------------------------------------------------------------------------- | ||
| 341 | !! set learnable parameter gradients | ||
| 342 | !!------------------------------------------------------------------------- | ||
| 343 | !! this = (T, io) network type | ||
| 344 | !! gradients = (R, in) gradients | ||
| 345 | !! note: no verbose argument; gradients is assumed-rank (any rank accepted) | ||
| 346 | module subroutine set_gradients(this, gradients) | ||
| 347 | class(network_type), intent(inout) :: this | ||
| 348 | real(real12), dimension(..), intent(in) :: gradients | ||
| 349 | end subroutine set_gradients | ||
| 350 | |||
| 351 | !!------------------------------------------------------------------------- | ||
| 352 | !! reset learnable parameter gradients | ||
| 353 | !!------------------------------------------------------------------------- | ||
| 354 | !! this = (T, io) network type | ||
| 355 | !! note: takes no arguments other than this (no verbose argument) | ||
| 356 | !!------------------------------------------------------------------------- | ||
| 357 | module subroutine reset_gradients(this) | ||
| 358 | class(network_type), intent(inout) :: this | ||
| 359 | end subroutine reset_gradients | ||
| 360 | |||
| 361 | !!------------------------------------------------------------------------- | ||
| 362 | !! forward pass | ||
| 363 | !!------------------------------------------------------------------------- | ||
| 364 | !! this = (T, io) network type | ||
| 365 | !! input = (R, in) input data | ||
| 366 | !! addit_input = (R, in, opt) additional input data | ||
| 367 | !! layer = (I, in, opt) layer to insert additional input data | ||
| 368 | pure module subroutine forward_1d(this, input, addit_input, layer) | ||
| 369 | class(network_type), intent(inout) :: this | ||
| 370 | real(real12), dimension(..), intent(in) :: input | ||
| 371 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 372 | integer, optional, intent(in) :: layer | ||
| 373 | end subroutine forward_1d | ||
| 374 | |||
| 375 | !!------------------------------------------------------------------------- | ||
| 376 | !! backward pass | ||
| 377 | !!------------------------------------------------------------------------- | ||
| 378 | !! this = (T, io) network type | ||
| 379 | !! output = (R, in) output data | ||
| 380 | pure module subroutine backward_1d(this, output) | ||
| 381 | class(network_type), intent(inout) :: this | ||
| 382 | real(real12), dimension(:,:), intent(in) :: output | ||
| 383 | end subroutine backward_1d | ||
| 384 | end interface | ||
| 385 | |||
| 386 | ✗ | end module network | |
| 387 | !!!############################################################################# | ||
| 388 |