| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | !!!############################################################################# | ||
| 2 | !!! Code written by Ned Thaddeus Taylor | ||
| 3 | !!! Code part of the ATHENA library - a feedforward neural network library | ||
| 4 | !!!############################################################################# | ||
| 5 | !!! submodule of the network module | ||
| 6 | !!! submodule contains the associated methods from the network module | ||
| 7 | !!!############################################################################# | ||
| 8 | ✗ | submodule(network) network_submodule | |
| 9 | #ifdef _OPENMP | ||
| 10 | use omp_lib | ||
| 11 | #endif | ||
| 12 | use misc_ml, only: shuffle | ||
| 13 | |||
| 14 | ✗ | use accuracy, only: categorical_score, mae_score, mse_score, r2_score | |
| 15 | use base_layer, only: & | ||
| 16 | input_layer_type, flatten_layer_type, & | ||
| 17 | drop_layer_type, & | ||
| 18 | learnable_layer_type, & | ||
| 19 | batch_layer_type, & | ||
| 20 | conv_layer_type, & | ||
| 21 | pool_layer_type | ||
| 22 | #if defined(GFORTRAN) | ||
| 23 | use container_layer, only: container_reduction | ||
| 24 | #endif | ||
| 25 | |||
| 26 | !! input layer types | ||
| 27 | use input1d_layer, only: input1d_layer_type | ||
| 28 | use input2d_layer, only: input2d_layer_type | ||
| 29 | use input3d_layer, only: input3d_layer_type | ||
| 30 | use input4d_layer, only: input4d_layer_type | ||
| 31 | |||
| 32 | !! batch normalisation layer types | ||
| 33 | use batchnorm1d_layer, only: batchnorm1d_layer_type, read_batchnorm1d_layer | ||
| 34 | use batchnorm2d_layer, only: batchnorm2d_layer_type, read_batchnorm2d_layer | ||
| 35 | use batchnorm3d_layer, only: batchnorm3d_layer_type, read_batchnorm3d_layer | ||
| 36 | |||
| 37 | !! convolution layer types | ||
| 38 | use conv1d_layer, only: conv1d_layer_type, read_conv1d_layer | ||
| 39 | use conv2d_layer, only: conv2d_layer_type, read_conv2d_layer | ||
| 40 | use conv3d_layer, only: conv3d_layer_type, read_conv3d_layer | ||
| 41 | |||
| 42 | !! dropout layer types | ||
| 43 | use dropout_layer, only: dropout_layer_type, read_dropout_layer | ||
| 44 | use dropblock2d_layer, only: dropblock2d_layer_type, read_dropblock2d_layer | ||
| 45 | use dropblock3d_layer, only: dropblock3d_layer_type, read_dropblock3d_layer | ||
| 46 | |||
| 47 | !! pooling layer types | ||
| 48 | use avgpool1d_layer, only: avgpool1d_layer_type, read_avgpool1d_layer | ||
| 49 | use avgpool2d_layer, only: avgpool2d_layer_type, read_avgpool2d_layer | ||
| 50 | use avgpool3d_layer, only: avgpool3d_layer_type, read_avgpool3d_layer | ||
| 51 | use maxpool1d_layer, only: maxpool1d_layer_type, read_maxpool1d_layer | ||
| 52 | use maxpool2d_layer, only: maxpool2d_layer_type, read_maxpool2d_layer | ||
| 53 | use maxpool3d_layer, only: maxpool3d_layer_type, read_maxpool3d_layer | ||
| 54 | |||
| 55 | !! flatten layer types | ||
| 56 | use flatten1d_layer, only: flatten1d_layer_type | ||
| 57 | use flatten2d_layer, only: flatten2d_layer_type | ||
| 58 | use flatten3d_layer, only: flatten3d_layer_type | ||
| 59 | use flatten4d_layer, only: flatten4d_layer_type | ||
| 60 | |||
| 61 | !! fully connected (dense) layer types | ||
| 62 | use full_layer, only: full_layer_type, read_full_layer | ||
| 63 | |||
| 64 | implicit none | ||
| 65 | |||
| 66 | ! #ifdef _OPENMP | ||
| 67 | ! !$omp declare reduction(network_reduction:network_type:omp_out%network_reduction(omp_in)) & | ||
| 68 | ! !$omp& initializer(omp_priv = omp_orig) | ||
| 69 | ! #endif | ||
| 70 | |||
| 71 | contains | ||
| 72 | |||
| 73 | !!!############################################################################# | ||
| 74 | !!! network addition | ||
| 75 | !!!############################################################################# | ||
| 76 | 1 | module subroutine network_reduction(this, source) | |
| 77 | implicit none | ||
| 78 | class(network_type), intent(inout) :: this | ||
| 79 | type(network_type), intent(in) :: source | ||
| 80 | |||
| 81 | integer :: i | ||
| 82 | |||
| 83 | 1 | this%metrics(1)%val = this%metrics(1)%val + source%metrics(1)%val | |
| 84 | 1 | this%metrics(2)%val = this%metrics(2)%val + source%metrics(2)%val | |
| 85 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
|
4 | do i=1,size(this%model) |
| 86 | 1 | select type(layer_this => this%model(i)%layer) | |
| 87 | class is(learnable_layer_type) | ||
| 88 | ✗ | select type(layer_source => source%model(i)%layer) | |
| 89 | class is(learnable_layer_type) | ||
| 90 | 2 | call layer_this%merge(layer_source) | |
| 91 | end select | ||
| 92 | end select | ||
| 93 | end do | ||
| 94 | |||
| 95 | 1 | end subroutine network_reduction | |
| 96 | !!!############################################################################# | ||
| 97 | |||
| 98 | |||
| 99 | !!!############################################################################# | ||
| 100 | !!! network addition | ||
| 101 | !!!############################################################################# | ||
| 102 | 1 | module subroutine network_copy(this, source) | |
| 103 | implicit none | ||
| 104 | class(network_type), intent(inout) :: this | ||
| 105 | type(network_type), intent(in) :: source | ||
| 106 | |||
| 107 |
6/10✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
|
3 | this%metrics = source%metrics |
| 108 |
17/54✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 3 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✓ Branch 42 taken 3 times.
✓ Branch 43 taken 1 times.
✓ Branch 44 taken 3 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 3 times.
✗ Branch 47 not taken.
✓ Branch 49 taken 3 times.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✓ Branch 52 taken 3 times.
✗ Branch 53 not taken.
✗ Branch 54 not taken.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
|
8 | this%model = source%model |
| 109 | 1 | end subroutine network_copy | |
| 110 | !!!############################################################################# | ||
| 111 | |||
| 112 | |||
| 113 | !!!##########################################################################!!! | ||
| 114 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
| 115 | !!!##########################################################################!!! | ||
| 116 | |||
| 117 | |||
| 118 | !!!############################################################################# | ||
| 119 | !!! print network to file | ||
| 120 | !!!############################################################################# | ||
| 121 | ✗ | module subroutine print(this, file) | |
| 122 | implicit none | ||
| 123 | class(network_type), intent(in) :: this | ||
| 124 | character(*), intent(in) :: file | ||
| 125 | |||
| 126 | integer :: l, unit | ||
| 127 | |||
| 128 | ✗ | open(newunit=unit,file=file,status='replace') | |
| 129 | ✗ | close(unit) | |
| 130 | |||
| 131 | ✗ | do l=1,this%num_layers | |
| 132 | ✗ | call this%model(l)%layer%print(file) | |
| 133 | end do | ||
| 134 | |||
| 135 | ✗ | end subroutine print | |
| 136 | !!!############################################################################# | ||
| 137 | |||
| 138 | |||
| 139 | !!!############################################################################# | ||
| 140 | !!! read network from file | ||
| 141 | !!!############################################################################# | ||
| 142 | ✗ | module subroutine read(this, file) | |
| 143 | implicit none | ||
| 144 | class(network_type), intent(inout) :: this | ||
| 145 | character(*), intent(in) :: file | ||
| 146 | |||
| 147 | integer :: i, unit, stat | ||
| 148 | character(256) :: buffer | ||
| 149 | ✗ | open(newunit=unit,file=file,action='read') | |
| 150 | ✗ | i = 0 | |
| 151 | ✗ | card_loop: do | |
| 152 | ✗ | i = i + 1 | |
| 153 | ✗ | read(unit,'(A)',iostat=stat) buffer | |
| 154 | ✗ | if(stat.lt.0)then | |
| 155 | ✗ | exit card_loop | |
| 156 | ✗ | elseif(stat.gt.0)then | |
| 157 | ✗ | write(0,*) "ERROR: error encountered in network read" | |
| 158 | ✗ | stop "Exiting..." | |
| 159 | end if | ||
| 160 | ✗ | if(trim(adjustl(buffer)).eq."") cycle card_loop | |
| 161 | |||
| 162 | !! check if a tag line | ||
| 163 | ✗ | if(scan(buffer,'=').ne.0)then | |
| 164 | ✗ | write(0,*) "WARNING: unexpected line in read file" | |
| 165 | ✗ | write(0,*) trim(buffer) | |
| 166 | ✗ | write(0,*) " skipping..." | |
| 167 | ✗ | cycle card_loop | |
| 168 | end if | ||
| 169 | |||
| 170 | !! check for card | ||
| 171 | ✗ | select case(trim(adjustl(buffer))) | |
| 172 | case("BATCHNORM1D") | ||
| 173 | ✗ | call this%add(read_batchnorm1d_layer(unit)) | |
| 174 | case("BATCHNORM2D") | ||
| 175 | ✗ | call this%add(read_batchnorm2d_layer(unit)) | |
| 176 | case("BATCHNORM3D") | ||
| 177 | ✗ | call this%add(read_batchnorm3d_layer(unit)) | |
| 178 | case("CONV1D") | ||
| 179 | ✗ | call this%add(read_conv1d_layer(unit)) | |
| 180 | case("CONV2D") | ||
| 181 | ✗ | call this%add(read_conv2d_layer(unit)) | |
| 182 | case("CONV3D") | ||
| 183 | ✗ | call this%add(read_conv3d_layer(unit)) | |
| 184 | case("DROPOUT") | ||
| 185 | ✗ | call this%add(read_dropout_layer(unit)) | |
| 186 | case("DROPBLOCK2D") | ||
| 187 | ✗ | call this%add(read_dropblock2d_layer(unit)) | |
| 188 | case("DROPBLOCK3D") | ||
| 189 | ✗ | call this%add(read_dropblock3d_layer(unit)) | |
| 190 | case("AVGPOOL1D") | ||
| 191 | ✗ | call this%add(read_avgpool1d_layer(unit)) | |
| 192 | case("AVGPOOL2D") | ||
| 193 | ✗ | call this%add(read_avgpool2d_layer(unit)) | |
| 194 | case("AVGPOOL3D") | ||
| 195 | ✗ | call this%add(read_avgpool3d_layer(unit)) | |
| 196 | case("MAXPOOL1D") | ||
| 197 | ✗ | call this%add(read_maxpool1d_layer(unit)) | |
| 198 | case("MAXPOOL2D") | ||
| 199 | ✗ | call this%add(read_maxpool2d_layer(unit)) | |
| 200 | case("MAXPOOL3D") | ||
| 201 | ✗ | call this%add(read_maxpool3d_layer(unit)) | |
| 202 | case("FULL") | ||
| 203 | ✗ | call this%add(read_full_layer(unit)) | |
| 204 | case default | ||
| 205 | write(0,*) "ERROR: unrecognised card '"//& | ||
| 206 | ✗ | &trim(adjustl(buffer))//"'" | |
| 207 | ✗ | stop "Exiting..." | |
| 208 | end select | ||
| 209 | end do card_loop | ||
| 210 | ✗ | close(unit) | |
| 211 | |||
| 212 | ✗ | end subroutine read | |
| 213 | !!!############################################################################# | ||
| 214 | |||
| 215 | |||
| 216 | !!!##########################################################################!!! | ||
| 217 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
| 218 | !!!##########################################################################!!! | ||
| 219 | |||
| 220 | |||
| 221 | !!!############################################################################# | ||
| 222 | !!! append layer to network | ||
| 223 | !!!############################################################################# | ||
| 224 | 22 | module subroutine add(this, layer) | |
| 225 | implicit none | ||
| 226 | class(network_type), intent(inout) :: this | ||
| 227 | class(base_layer_type), intent(in) :: layer | ||
| 228 | |||
| 229 | character(4) :: name | ||
| 230 | |||
| 231 | select type(layer) | ||
| 232 | class is(input_layer_type) | ||
| 233 | 1 | name = "inpt" | |
| 234 | class is(batch_layer_type) | ||
| 235 | 1 | name = "batc" | |
| 236 | class is(conv_layer_type) | ||
| 237 | 7 | name = "conv" | |
| 238 | class is(flatten_layer_type) | ||
| 239 | 1 | name = "flat" | |
| 240 | class is(drop_layer_type) | ||
| 241 | 1 | name = "drop" | |
| 242 | class is(pool_layer_type) | ||
| 243 | 1 | name = "pool" | |
| 244 | type is(full_layer_type) | ||
| 245 | 10 | name = "full" | |
| 246 | class default | ||
| 247 | ✗ | name = "unkw" | |
| 248 | end select | ||
| 249 | |||
| 250 |
2/2✓ Branch 0 taken 9 times.
✓ Branch 1 taken 13 times.
|
22 | if(.not.allocated(this%model))then |
| 251 |
12/42✓ Branch 0 taken 9 times.
✓ Branch 1 taken 9 times.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 9 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 9 times.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✓ Branch 11 taken 9 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 9 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 9 times.
✓ Branch 16 taken 9 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✓ Branch 33 taken 9 times.
✓ Branch 34 taken 9 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 9 times.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
|
45 | this%model = [container_layer_type(name=name)] |
| 252 | 9 | this%num_layers = 1 | |
| 253 | else | ||
| 254 |
31/56✗ Branch 0 not taken.
✓ Branch 1 taken 13 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 13 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 13 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 13 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 13 times.
✓ Branch 15 taken 21 times.
✓ Branch 16 taken 13 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 13 times.
✓ Branch 19 taken 34 times.
✓ Branch 20 taken 13 times.
✓ Branch 21 taken 34 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 21 times.
✓ Branch 24 taken 13 times.
✓ Branch 26 taken 13 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 29 taken 13 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 13 times.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 13 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 21 times.
✓ Branch 39 taken 13 times.
✓ Branch 40 taken 21 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 21 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 10 times.
✓ Branch 45 taken 11 times.
✓ Branch 46 taken 10 times.
✓ Branch 47 taken 11 times.
✓ Branch 48 taken 13 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 34 times.
✓ Branch 51 taken 13 times.
✓ Branch 52 taken 34 times.
✓ Branch 53 taken 13 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 34 times.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
|
183 | this%model = [this%model(1:), container_layer_type(name=name)] |
| 255 | 13 | this%num_layers = this%num_layers + 1 | |
| 256 | end if | ||
| 257 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 22 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 22 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 22 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 22 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 22 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 22 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 22 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 22 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 22 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 22 times.
|
22 | allocate(this%model(size(this%model,dim=1))%layer, source=layer) |
| 258 | |||
| 259 | 22 | end subroutine add | |
| 260 | !!!############################################################################# | ||
| 261 | |||
| 262 | |||
| 263 | !!!############################################################################# | ||
| 264 | !!! set up network | ||
| 265 | !!!############################################################################# | ||
| 266 | 1 | module function network_setup( & | |
| 267 | 1 | layers, optimiser, loss_method, metrics, batch_size) result(network) | |
| 268 | implicit none | ||
| 269 | type(container_layer_type), dimension(:), intent(in) :: layers | ||
| 270 | class(base_optimiser_type), optional, intent(in) :: optimiser | ||
| 271 | character(*), optional, intent(in) :: loss_method | ||
| 272 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 273 | integer, optional, intent(in) :: batch_size | ||
| 274 | |||
| 275 | type(network_type) :: network | ||
| 276 | |||
| 277 | integer :: l | ||
| 278 | |||
| 279 | |||
| 280 | !!!----------------------------------------------------------------------------- | ||
| 281 | !!! handle optional arguments | ||
| 282 | !!!----------------------------------------------------------------------------- | ||
| 283 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(loss_method)) call network%set_loss(loss_method) |
| 284 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1 | if(present(metrics)) call network%set_metrics(metrics) |
| 285 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(batch_size)) network%batch_size = batch_size |
| 286 | |||
| 287 | |||
| 288 | !!!----------------------------------------------------------------------------- | ||
| 289 | !!! add layers to network | ||
| 290 | !!!----------------------------------------------------------------------------- | ||
| 291 |
5/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 9 taken 3 times.
✓ Branch 10 taken 1 times.
|
4 | do l = 1, size(layers) |
| 292 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
|
4 | call network%add(layers(l)%layer) |
| 293 | end do | ||
| 294 | |||
| 295 | |||
| 296 | !!!----------------------------------------------------------------------------- | ||
| 297 | !!! compile network if optimiser present | ||
| 298 | !!!----------------------------------------------------------------------------- | ||
| 299 |
1/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
|
1 | if(present(optimiser)) call network%compile(optimiser) |
| 300 | |||
| 301 |
5/8✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
|
4 | end function network_setup |
| 302 | !!!############################################################################# | ||
| 303 | |||
| 304 | |||
| 305 | !!!############################################################################# | ||
| 306 | !!! set network metrics | ||
| 307 | !!!############################################################################# | ||
| 308 | 12 | module subroutine set_metrics(this, metrics) | |
| 309 | use misc, only: to_lower | ||
| 310 | implicit none | ||
| 311 | class(network_type), intent(inout) :: this | ||
| 312 | class(*), dimension(..), intent(in) :: metrics | ||
| 313 | |||
| 314 | integer :: i | ||
| 315 | |||
| 316 | |||
| 317 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
|
36 | this%metrics%active = .false. |
| 318 | 12 | this%metrics(1)%key = "loss" | |
| 319 | 12 | this%metrics(2)%key = "accuracy" | |
| 320 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 12 times.
|
36 | this%metrics%threshold = 1.E-1_real12 |
| 321 | select rank(metrics) | ||
| 322 | #if defined(GFORTRAN) | ||
| 323 | rank(0) | ||
| 324 | select type(metrics) | ||
| 325 | type is(character(*)) | ||
| 326 | !! ERROR: ifort cannot identify that the rank of metrics has been ... | ||
| 327 | !! ... identified as scalar here | ||
| 328 | where(to_lower(trim(metrics)).eq.this%metrics%key) | ||
| 329 | this%metrics%active = .true. | ||
| 330 | end where | ||
| 331 | end select | ||
| 332 | #endif | ||
| 333 | rank(1) | ||
| 334 | 12 | select type(metrics) | |
| 335 | type is(character(*)) | ||
| 336 |
2/2✓ Branch 0 taken 11 times.
✓ Branch 1 taken 11 times.
|
22 | do i=1,size(metrics,1) |
| 337 |
7/10✗ Branch 0 not taken.
✓ Branch 1 taken 11 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 11 times.
✓ Branch 8 taken 11 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 22 times.
✓ Branch 11 taken 11 times.
✓ Branch 12 taken 11 times.
✓ Branch 13 taken 11 times.
|
44 | where(to_lower(trim(metrics(i))).eq.this%metrics%key) |
| 338 | this%metrics%active = .true. | ||
| 339 | end where | ||
| 340 | end do | ||
| 341 | type is(metric_dict_type) | ||
| 342 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(size(metrics,1).eq.2)then |
| 343 |
12/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✓ Branch 15 taken 2 times.
✓ Branch 16 taken 1 times.
✓ Branch 17 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 1 times.
✓ Branch 21 taken 2 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 2 times.
✗ Branch 24 not taken.
|
3 | this%metrics(:2) = metrics(:2) |
| 344 | else | ||
| 345 | ✗ | stop "ERROR: invalid length array for metric_dict_type" | |
| 346 | end if | ||
| 347 | end select | ||
| 348 | rank default | ||
| 349 | ✗ | stop "ERROR: provided metrics rank in compile invalid" | |
| 350 | end select | ||
| 351 | |||
| 352 | 12 | end subroutine set_metrics | |
| 353 | !!!############################################################################# | ||
| 354 | |||
| 355 | |||
| 356 | !!!############################################################################# | ||
| 357 | !!! set network loss | ||
| 358 | !!!############################################################################# | ||
| 359 | 12 | module subroutine set_loss(this, loss_method, verbose) | |
| 360 | use misc, only: to_lower | ||
| 361 | use loss, only: & | ||
| 362 | compute_loss_bce, compute_loss_cce, & | ||
| 363 | compute_loss_mae, compute_loss_mse, & | ||
| 364 | compute_loss_nll | ||
| 365 | implicit none | ||
| 366 | class(network_type), intent(inout) :: this | ||
| 367 | character(*), intent(in) :: loss_method | ||
| 368 | integer, optional, intent(in) :: verbose | ||
| 369 | |||
| 370 | integer :: verbose_ | ||
| 371 | 12 | character(len=:), allocatable :: loss_method_ | |
| 372 | |||
| 373 | |||
| 374 | 12 | if(present(verbose))then | |
| 375 | 7 | verbose_ = verbose | |
| 376 | else | ||
| 377 | 5 | verbose_ = 0 | |
| 378 | end if | ||
| 379 | |||
| 380 | !!!----------------------------------------------------------------------------- | ||
| 381 | !!! handle analogous definitions | ||
| 382 | !!!----------------------------------------------------------------------------- | ||
| 383 |
3/8✓ Branch 1 taken 12 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 12 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 12 times.
|
12 | loss_method_ = to_lower(loss_method) |
| 384 | 1 | select case(loss_method) | |
| 385 | case("binary_crossentropy") | ||
| 386 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
|
1 | loss_method_ = "bce" |
| 387 | case("categorical_crossentropy") | ||
| 388 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
|
1 | loss_method_ = "cce" |
| 389 | case("mean_absolute_error") | ||
| 390 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
|
1 | loss_method_ = "mae" |
| 391 | case("mean_squared_error") | ||
| 392 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
|
1 | loss_method_ = "mse" |
| 393 | case("negative_log_likelihood") | ||
| 394 |
10/14✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✓ Branch 5 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
|
12 | loss_method_ = "nll" |
| 395 | end select | ||
| 396 | |||
| 397 | !!!----------------------------------------------------------------------------- | ||
| 398 | !!! set loss method | ||
| 399 | !!!----------------------------------------------------------------------------- | ||
| 400 | 1 | select case(loss_method_) | |
| 401 | case("bce") | ||
| 402 | 1 | this%get_loss => compute_loss_bce | |
| 403 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy" |
| 404 | case("cce") | ||
| 405 | 1 | this%get_loss => compute_loss_cce | |
| 406 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(verbose_.gt.0) write(*,*) "Loss method: Categorical Cross Entropy" |
| 407 | case("mae") | ||
| 408 | 1 | this%get_loss => compute_loss_mae | |
| 409 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(verbose_.gt.0) write(*,*) "Loss method: Mean Absolute Error" |
| 410 | case("mse") | ||
| 411 | 8 | this%get_loss => compute_loss_mse | |
| 412 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 times.
|
15 | if(verbose_.gt.0) write(*,*) "Loss method: Mean Squared Error" |
| 413 | case("nll") | ||
| 414 | 1 | this%get_loss => compute_loss_nll | |
| 415 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(verbose_.gt.0) write(*,*) "Loss method: Negative log likelihood" |
| 416 | case default | ||
| 417 | ✗ | write(0,*) "Failed loss method: "//trim(loss_method_) | |
| 418 |
5/6✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
|
12 | stop "ERROR: No loss method provided" |
| 419 | end select | ||
| 420 | 12 | this%get_loss_deriv => comp_loss_deriv | |
| 421 | |||
| 422 |
3/4✓ Branch 0 taken 7 times.
✓ Branch 1 taken 5 times.
✓ Branch 2 taken 12 times.
✗ Branch 3 not taken.
|
24 | end subroutine set_loss |
| 423 | !!!############################################################################# | ||
| 424 | |||
| 425 | |||
| 426 | !!!############################################################################# | ||
| 427 | !!! reset network | ||
| 428 | !!!############################################################################# | ||
| 429 | 4 | module subroutine reset(this) | |
| 430 | implicit none | ||
| 431 | class(network_type), intent(inout) :: this | ||
| 432 | |||
| 433 | 4 | this%accuracy = 0._real12 | |
| 434 | 4 | this%loss = huge(1._real12) | |
| 435 | 4 | this%batch_size = 0 | |
| 436 | 4 | this%num_layers = 0 | |
| 437 | 4 | this%num_outputs = 0 | |
| 438 |
4/6✓ Branch 0 taken 2 times.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
|
6 | if(allocated(this%optimiser)) deallocate(this%optimiser) |
| 439 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | call this%set_metrics(["loss"]) |
| 440 |
11/20✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✓ Branch 5 taken 4 times.
✓ Branch 6 taken 14 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 11 taken 14 times.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✓ Branch 14 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 14 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 4 times.
|
22 | if(allocated(this%model)) deallocate(this%model) |
| 441 | 4 | this%get_loss => null() | |
| 442 | 4 | this%get_loss_deriv => null() | |
| 443 | |||
| 444 | 4 | end subroutine reset | |
| 445 | !!!############################################################################# | ||
| 446 | |||
| 447 | |||
| 448 | !!!############################################################################# | ||
| 449 | !!! compile network | ||
| 450 | !!!############################################################################# | ||
| 451 | 7 | module subroutine compile(this, optimiser, loss_method, metrics, batch_size, & | |
| 452 | verbose) | ||
| 453 | implicit none | ||
| 454 | class(network_type), intent(inout) :: this | ||
| 455 | class(base_optimiser_type), intent(in) :: optimiser | ||
| 456 | character(*), optional, intent(in) :: loss_method | ||
| 457 | class(*), dimension(..), optional, intent(in) :: metrics | ||
| 458 | integer, optional, intent(in) :: batch_size | ||
| 459 | integer, optional, intent(in) :: verbose | ||
| 460 | |||
| 461 | integer :: i | ||
| 462 | integer :: verbose_ = 0, num_addit_inputs | ||
| 463 | 23 | class(base_layer_type), allocatable :: t_input_layer, t_flatten_layer | |
| 464 | |||
| 465 | |||
| 466 | !!!----------------------------------------------------------------------------- | ||
| 467 | !!! initialise optional arguments | ||
| 468 | !!!----------------------------------------------------------------------------- | ||
| 469 | 7 | if(present(verbose)) verbose_ = verbose | |
| 470 | |||
| 471 | |||
| 472 | !!!----------------------------------------------------------------------------- | ||
| 473 | !!! initialise metrics | ||
| 474 | !!!----------------------------------------------------------------------------- | ||
| 475 |
2/4✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7 times.
✗ Branch 3 not taken.
|
7 | if(present(metrics)) call this%set_metrics(metrics) |
| 476 | |||
| 477 | |||
| 478 | !!!----------------------------------------------------------------------------- | ||
| 479 | !!! initialise loss method | ||
| 480 | !!!----------------------------------------------------------------------------- | ||
| 481 |
1/2✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
|
7 | if(present(loss_method)) call this%set_loss(loss_method, verbose_) |
| 482 | |||
| 483 | |||
| 484 | !!!----------------------------------------------------------------------------- | ||
| 485 | !!! check for input layer | ||
| 486 | !!!----------------------------------------------------------------------------- | ||
| 487 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
|
7 | if(.not.allocated(this%model(1)%layer%input_shape))then |
| 488 | ✗ | stop "ERROR: input_shape of first layer not defined" | |
| 489 | end if | ||
| 490 | |||
| 491 | 14 | select type(first => this%model(1)%layer) | |
| 492 | class is(input_layer_type) | ||
| 493 | class default | ||
| 494 | this%model = [& | ||
| 495 | 28 | container_layer_type(name="inpt"),& | |
| 496 | this%model(1:)& | ||
| 497 |
31/56✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✓ Branch 15 taken 14 times.
✓ Branch 16 taken 7 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 7 times.
✓ Branch 19 taken 21 times.
✓ Branch 20 taken 7 times.
✓ Branch 21 taken 21 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 14 times.
✓ Branch 24 taken 7 times.
✓ Branch 26 taken 7 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 29 taken 7 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 7 times.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 7 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 14 times.
✓ Branch 39 taken 7 times.
✓ Branch 40 taken 14 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 14 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 7 times.
✓ Branch 45 taken 7 times.
✓ Branch 46 taken 7 times.
✓ Branch 47 taken 7 times.
✓ Branch 48 taken 7 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 21 times.
✓ Branch 51 taken 7 times.
✓ Branch 52 taken 21 times.
✓ Branch 53 taken 7 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 21 times.
✗ Branch 56 not taken.
✗ Branch 57 not taken.
✗ Branch 58 not taken.
✗ Branch 59 not taken.
✗ Branch 60 not taken.
✗ Branch 61 not taken.
|
112 | ] |
| 498 | 14 | associate(next => this%model(2)%layer) | |
| 499 | 10 | select case(size(next%input_shape,dim=1)) | |
| 500 | case(1) | ||
| 501 |
7/18✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 3 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 3 times.
|
3 | t_input_layer = input1d_layer_type(input_shape = next%input_shape) |
| 502 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3 times.
|
3 | allocate(this%model(1)%layer, source = t_input_layer) |
| 503 | case(2) | ||
| 504 | ✗ | select type(next) | |
| 505 | type is(conv1d_layer_type) | ||
| 506 | t_input_layer = input2d_layer_type(& | ||
| 507 | ✗ | input_shape = next%input_shape + & | |
| 508 | ✗ | [2*next%pad,0]) | |
| 509 | ✗ | allocate(this%model(1)%layer, source = t_input_layer) | |
| 510 | class default | ||
| 511 | t_input_layer = input2d_layer_type(& | ||
| 512 | ✗ | input_shape = next%input_shape) | |
| 513 | ✗ | allocate(this%model(1)%layer, source = t_input_layer) | |
| 514 | end select | ||
| 515 | case(3) | ||
| 516 | 2 | select type(next) | |
| 517 | type is(conv2d_layer_type) | ||
| 518 | t_input_layer = input3d_layer_type(& | ||
| 519 | ✗ | input_shape = next%input_shape + & | |
| 520 |
20/40✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 2 times.
✓ Branch 25 taken 4 times.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 29 taken 6 times.
✓ Branch 30 taken 2 times.
✓ Branch 32 taken 2 times.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✓ Branch 35 taken 2 times.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 2 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 2 times.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✓ Branch 47 taken 2 times.
|
12 | [2*next%pad,0]) |
| 521 |
11/22✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
|
4 | allocate(this%model(1)%layer, source = t_input_layer) |
| 522 | class default | ||
| 523 | t_input_layer = input3d_layer_type(& | ||
| 524 | ✗ | input_shape = next%input_shape) | |
| 525 | ✗ | allocate(this%model(1)%layer, source = t_input_layer) | |
| 526 | end select | ||
| 527 | case(4) | ||
| 528 | 3 | select type(next) | |
| 529 | type is(conv3d_layer_type) | ||
| 530 | t_input_layer = input4d_layer_type(& | ||
| 531 | ✗ | input_shape = next%input_shape + & | |
| 532 |
20/40✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 2 times.
✓ Branch 25 taken 6 times.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 29 taken 8 times.
✓ Branch 30 taken 2 times.
✓ Branch 32 taken 2 times.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✓ Branch 35 taken 2 times.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 2 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 2 times.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✓ Branch 47 taken 2 times.
|
16 | [2*next%pad,0]) |
| 533 |
11/22✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
|
4 | allocate(this%model(1)%layer, source = t_input_layer) |
| 534 | class default | ||
| 535 | t_input_layer = input4d_layer_type(& | ||
| 536 | ✗ | input_shape = next%input_shape) | |
| 537 | ✗ | allocate(this%model(1)%layer, source = t_input_layer) | |
| 538 | end select | ||
| 539 | end select | ||
| 540 |
7/13✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 2 times.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✓ Branch 12 taken 7 times.
✓ Branch 14 taken 7 times.
✗ Branch 15 not taken.
|
21 | deallocate(t_input_layer) |
| 541 | end associate | ||
| 542 | end select | ||
| 543 | |||
| 544 | |||
| 545 | !!!----------------------------------------------------------------------------- | ||
| 546 | !!! ignore calcuation of input gradients for 1st non-input layer | ||
| 547 | !!!----------------------------------------------------------------------------- | ||
| 548 | 14 | select type(second => this%model(2)%layer) | |
| 549 | class is(conv_layer_type) | ||
| 550 | 4 | second%calc_input_gradients = .false. | |
| 551 | end select | ||
| 552 | |||
| 553 | |||
| 554 | !!!----------------------------------------------------------------------------- | ||
| 555 | !!! initialise layers | ||
| 556 | !!!----------------------------------------------------------------------------- | ||
| 557 |
1/2✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
|
7 | if(verbose_.gt.0)then |
| 558 |
2/4✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
|
7 | write(*,*) "layer:",1, this%model(1)%name |
| 559 |
8/14✗ Branch 1 not taken.
✓ Branch 2 taken 7 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 7 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 7 times.
✓ Branch 19 taken 17 times.
✓ Branch 20 taken 7 times.
|
24 | write(*,*) this%model(1)%layer%input_shape |
| 560 |
8/14✗ Branch 1 not taken.
✓ Branch 2 taken 7 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 7 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 7 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 7 times.
✓ Branch 19 taken 17 times.
✓ Branch 20 taken 7 times.
|
24 | write(*,*) this%model(1)%layer%output_shape |
| 561 | end if | ||
| 562 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 7 times.
|
21 | do i=2,size(this%model,dim=1) |
| 563 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✓ Branch 6 taken 7 times.
✓ Branch 7 taken 7 times.
|
14 | if(.not.allocated(this%model(i)%layer%output_shape)) & |
| 564 | 28 | call this%model(i)%layer%init(this%model(i-1)%layer%output_shape, & | |
| 565 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
|
7 | this%batch_size) |
| 566 |
1/2✓ Branch 0 taken 14 times.
✗ Branch 1 not taken.
|
21 | if(verbose_.gt.0)then |
| 567 |
2/4✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
|
14 | write(*,*) "layer:",i, this%model(i)%name |
| 568 |
8/14✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 14 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 14 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 14 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 14 times.
✓ Branch 19 taken 29 times.
✓ Branch 20 taken 14 times.
|
43 | write(*,*) this%model(i)%layer%input_shape |
| 569 |
8/14✗ Branch 1 not taken.
✓ Branch 2 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 14 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 14 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 14 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 14 times.
✓ Branch 19 taken 29 times.
✓ Branch 20 taken 14 times.
|
43 | write(*,*) this%model(i)%layer%output_shape |
| 570 | end if | ||
| 571 | end do | ||
| 572 | |||
| 573 | |||
| 574 | !!!----------------------------------------------------------------------------- | ||
| 575 | !!! check for required reshape layers | ||
| 576 | !!!----------------------------------------------------------------------------- | ||
| 577 | 7 | i = 1 !! starting for layer 2 | |
| 578 | 14 | layer_loop: do | |
| 579 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 14 times.
|
21 | if(i.ge.size(this%model,dim=1)) exit layer_loop |
| 580 | 14 | i = i + 1 | |
| 581 | |||
| 582 |
2/2✓ Branch 0 taken 7 times.
✓ Branch 1 taken 7 times.
|
14 | flatten_layer_check: if(i.lt.size(this%model,dim=1))then |
| 583 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 7 times.
✗ Branch 7 not taken.
|
14 | if(allocated(this%model(i+1)%layer%input_shape).and.& |
| 584 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
|
7 | allocated(this%model(i)%layer%output_shape))then |
| 585 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 2 times.
✓ Branch 7 taken 5 times.
|
14 | if(size(this%model(i+1)%layer%input_shape).ne.& |
| 586 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
|
7 | size(this%model(i)%layer%output_shape))then |
| 587 | |||
| 588 | 4 | select type(current => this%model(i)%layer) | |
| 589 | class is(flatten_layer_type) | ||
| 590 | 2 | cycle layer_loop | |
| 591 | class default | ||
| 592 | 8 | this%model = [& | |
| 593 | this%model(1:i),& | ||
| 594 | 8 | container_layer_type(name="flat"),& | |
| 595 | this%model(i+1:size(this%model))& | ||
| 596 |
36/68✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✓ Branch 15 taken 4 times.
✓ Branch 16 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 2 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✓ Branch 32 taken 2 times.
✓ Branch 33 taken 2 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 2 times.
✓ Branch 36 taken 8 times.
✓ Branch 37 taken 2 times.
✓ Branch 38 taken 8 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✓ Branch 41 taken 2 times.
✓ Branch 43 taken 2 times.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✓ Branch 46 taken 2 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 2 times.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✗ Branch 51 not taken.
✗ Branch 52 not taken.
✓ Branch 53 taken 2 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 6 times.
✓ Branch 56 taken 2 times.
✓ Branch 57 taken 6 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 6 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 6 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 6 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 2 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 8 times.
✓ Branch 68 taken 2 times.
✓ Branch 69 taken 8 times.
✓ Branch 70 taken 2 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 8 times.
✗ Branch 73 not taken.
✗ Branch 74 not taken.
✗ Branch 75 not taken.
✗ Branch 76 not taken.
✗ Branch 77 not taken.
✗ Branch 78 not taken.
|
42 | ] |
| 597 | 2 | num_addit_inputs = 0 | |
| 598 | 4 | select type(next => this%model(i+1)%layer) | |
| 599 | type is(full_layer_type) | ||
| 600 | ✗ | num_addit_inputs = next%num_addit_inputs | |
| 601 | end select | ||
| 602 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
|
2 | select case(size(this%model(i)%layer%output_shape)) |
| 603 | case(2) | ||
| 604 | t_flatten_layer = flatten1d_layer_type(& | ||
| 605 | ✗ | input_shape = this%model(i)%layer%output_shape, & | |
| 606 | num_addit_outputs = num_addit_inputs, & | ||
| 607 | ✗ | batch_size = this%batch_size) | |
| 608 | ✗ | allocate(this%model(i+1)%layer, source=t_flatten_layer) | |
| 609 | case(3) | ||
| 610 | t_flatten_layer = flatten2d_layer_type(& | ||
| 611 | 2 | input_shape = this%model(i)%layer%output_shape, & | |
| 612 | num_addit_outputs = num_addit_inputs, & | ||
| 613 |
10/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✓ Branch 27 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
|
1 | batch_size = this%batch_size) |
| 614 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
|
1 | allocate(this%model(i+1)%layer, source=t_flatten_layer) |
| 615 | case(4) | ||
| 616 | t_flatten_layer = flatten3d_layer_type(& | ||
| 617 | 2 | input_shape = this%model(i)%layer%output_shape, & | |
| 618 | num_addit_outputs = num_addit_inputs, & | ||
| 619 |
10/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✓ Branch 27 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
|
1 | batch_size = this%batch_size) |
| 620 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
|
1 | allocate(this%model(i+1)%layer, source=t_flatten_layer) |
| 621 | case(5) | ||
| 622 | t_flatten_layer = flatten4d_layer_type(& | ||
| 623 | ✗ | input_shape = this%model(i)%layer%output_shape, & | |
| 624 | num_addit_outputs = num_addit_inputs, & | ||
| 625 | ✗ | batch_size = this%batch_size) | |
| 626 |
2/25✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
|
4 | allocate(this%model(i+1)%layer, source=t_flatten_layer) |
| 627 | end select | ||
| 628 | 2 | i = i + 1 | |
| 629 | 2 | cycle layer_loop | |
| 630 | end select | ||
| 631 | end if | ||
| 632 | else | ||
| 633 | |||
| 634 | end if | ||
| 635 | end if flatten_layer_check | ||
| 636 | |||
| 637 | end do layer_loop | ||
| 638 | |||
| 639 | !! update number of layers | ||
| 640 | !!-------------------------------------------------------------------------- | ||
| 641 | 7 | this%num_layers = i | |
| 642 | |||
| 643 | |||
| 644 | !! set number of outputs | ||
| 645 | !!-------------------------------------------------------------------------- | ||
| 646 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 18 taken 12 times.
✓ Branch 19 taken 7 times.
|
19 | this%num_outputs = product(this%model(this%num_layers)%layer%output_shape) |
| 647 | |||
| 648 | |||
| 649 | !!!----------------------------------------------------------------------------- | ||
| 650 | !!! initialise optimiser | ||
| 651 | !!!----------------------------------------------------------------------------- | ||
| 652 |
3/10✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 5 taken 7 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✓ Branch 8 taken 7 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
|
7 | this%optimiser = optimiser |
| 653 | 7 | call this%optimiser%init(num_params=this%get_num_params()) | |
| 654 | |||
| 655 | |||
| 656 | !!!----------------------------------------------------------------------------- | ||
| 657 | !!! set batch size, if provided | ||
| 658 | !!!----------------------------------------------------------------------------- | ||
| 659 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 3 times.
|
7 | if(present(batch_size)) this%batch_size = batch_size |
| 660 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 3 times.
|
7 | if(this%batch_size.ne.0)then |
| 661 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
|
4 | if(this%model(1)%layer%batch_size.ne.0.and.& |
| 662 | this%model(1)%layer%batch_size.ne.this%batch_size)then | ||
| 663 | ✗ | write(*,*) "WARNING: batch_size in compile differs from batch_size of input layer" | |
| 664 | ✗ | write(*,*) " batch_size of input layer will be set to network batch_size" | |
| 665 | end if | ||
| 666 | 4 | call this%set_batch_size(this%batch_size) | |
| 667 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
|
3 | elseif(this%model(1)%layer%batch_size.ne.0)then |
| 668 | ✗ | call this%set_batch_size(this%model(1)%layer%batch_size) | |
| 669 | end if | ||
| 670 | |||
| 671 |
5/10✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 7 taken 2 times.
✓ Branch 8 taken 5 times.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
|
14 | end subroutine compile |
| 672 | !!!############################################################################# | ||
| 673 | |||
| 674 | |||
| 675 | !!!############################################################################# | ||
| 676 | !!! set batch size | ||
| 677 | !!!############################################################################# | ||
| 678 | 17 | module subroutine set_batch_size(this, batch_size) | |
| 679 | implicit none | ||
| 680 | class(network_type), intent(inout) :: this | ||
| 681 | integer, intent(in) :: batch_size | ||
| 682 | |||
| 683 | integer :: l | ||
| 684 | |||
| 685 | 17 | this%batch_size = batch_size | |
| 686 |
2/2✓ Branch 0 taken 43 times.
✓ Branch 1 taken 17 times.
|
60 | do l=1,this%num_layers |
| 687 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
|
60 | call this%model(l)%layer%set_batch_size(this%batch_size) |
| 688 | end do | ||
| 689 | |||
| 690 | 17 | end subroutine set_batch_size | |
| 691 | !!!############################################################################# | ||
| 692 | |||
| 693 | |||
| 694 | !!!##########################################################################!!! | ||
| 695 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
| 696 | !!!##########################################################################!!! | ||
| 697 | |||
| 698 | |||
| 699 | !!!############################################################################# | ||
| 700 | !!! return sample from any rank | ||
| 701 | !!!############################################################################# | ||
| 702 | 509 | pure function get_sample(input, start_index, end_index) result(output) | |
| 703 | implicit none | ||
| 704 | integer, intent(in) :: start_index, end_index | ||
| 705 | real(real12), dimension(..), intent(in) :: input | ||
| 706 | real(real12), allocatable, dimension(:,:) :: output | ||
| 707 | |||
| 708 | select rank(input) | ||
| 709 | rank(2) | ||
| 710 | ✗ | output = reshape(input(:,start_index:end_index), & | |
| 711 |
17/32✗ Branch 0 not taken.
✓ Branch 1 taken 509 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 509 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 509 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 509 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 509 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 509 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 509 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 509 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 509 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 509 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 509 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 509 times.
✓ Branch 28 taken 1018 times.
✓ Branch 29 taken 509 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 509 times.
✓ Branch 33 taken 509 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 509 times.
✗ Branch 36 not taken.
|
1527 | shape=[size(input(:,1)),end_index-start_index+1]) |
| 712 | rank(3) | ||
| 713 | ✗ | output = reshape(input(:,:,start_index:end_index), & | |
| 714 | ✗ | shape=[size(input(:,:,1)),end_index-start_index+1]) | |
| 715 | rank(4) | ||
| 716 | ✗ | output = reshape(input(:,:,:,start_index:end_index), & | |
| 717 | ✗ | shape=[size(input(:,:,:,1)),end_index-start_index+1]) | |
| 718 | rank(5) | ||
| 719 | ✗ | output = reshape(input(:,:,:,:,start_index:end_index), & | |
| 720 | ✗ | shape=[size(input(:,:,:,:,1)),end_index-start_index+1]) | |
| 721 | rank(6) | ||
| 722 | ✗ | output = reshape(input(:,:,:,:,:,start_index:end_index), & | |
| 723 | ✗ | shape=[size(input(:,:,:,:,:,1)),end_index-start_index+1]) | |
| 724 | end select | ||
| 725 | |||
| 726 | 509 | end function get_sample | |
| 727 | !!!############################################################################# | ||
| 728 | |||
| 729 | |||
| 730 | !!!############################################################################# | ||
| 731 | !!! get number of parameters | ||
| 732 | !!!############################################################################# | ||
| 733 | 1024 | pure module function get_num_params(this) result(num_params) | |
| 734 | implicit none | ||
| 735 | class(network_type), intent(in) :: this | ||
| 736 | integer :: num_params | ||
| 737 | |||
| 738 | integer :: l | ||
| 739 | |||
| 740 | 1024 | num_params = 0 | |
| 741 |
2/2✓ Branch 0 taken 3058 times.
✓ Branch 1 taken 1024 times.
|
4082 | do l = 1, this%num_layers |
| 742 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 3058 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3058 times.
|
4082 | num_params = num_params + this%model(l)%layer%get_num_params() |
| 743 | end do | ||
| 744 | |||
| 745 | 1024 | end function get_num_params | |
| 746 | !!!############################################################################# | ||
| 747 | |||
| 748 | |||
| 749 | !!!############################################################################# | ||
| 750 | !!! get learnable parameters | ||
| 751 | !!!############################################################################# | ||
| 752 | 507 | pure module function get_params(this) result(params) | |
| 753 | implicit none | ||
| 754 | class(network_type), intent(in) :: this | ||
| 755 | real(real12), allocatable, dimension(:) :: params | ||
| 756 | |||
| 757 | integer :: l, start_idx, end_idx | ||
| 758 | |||
| 759 | 507 | start_idx = 0 | |
| 760 | 507 | end_idx = 0 | |
| 761 |
13/24✓ Branch 1 taken 507 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 507 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 507 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 507 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 507 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 507 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 507 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 507 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 507 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 507 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 507 times.
✓ Branch 30 taken 15984 times.
✓ Branch 31 taken 507 times.
|
16491 | allocate(params(this%get_num_params()), source=0._real12) |
| 762 |
2/2✓ Branch 0 taken 1513 times.
✓ Branch 1 taken 507 times.
|
2020 | do l = 1, this%num_layers |
| 763 | 507 | select type(current => this%model(l)%layer) | |
| 764 | class is(learnable_layer_type) | ||
| 765 | 1006 | start_idx = end_idx + 1 | |
| 766 | 2012 | end_idx = end_idx + current%get_num_params() | |
| 767 |
6/10✗ Branch 1 not taken.
✓ Branch 2 taken 1006 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1006 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1006 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1006 times.
✓ Branch 14 taken 15984 times.
✓ Branch 15 taken 1006 times.
|
16990 | params(start_idx:end_idx) = current%get_params() |
| 768 | end select | ||
| 769 | end do | ||
| 770 | |||
| 771 | 507 | end function get_params | |
| 772 | !!!############################################################################# | ||
| 773 | |||
| 774 | |||
| 775 | !!!############################################################################# | ||
| 776 | !!! set learnable parameters | ||
| 777 | !!!############################################################################# | ||
| 778 |
1/2✓ Branch 0 taken 507 times.
✗ Branch 1 not taken.
|
507 | module subroutine set_params(this, params) |
| 779 | implicit none | ||
| 780 | class(network_type), intent(inout) :: this | ||
| 781 | real(real12), dimension(:), intent(in) :: params | ||
| 782 | |||
| 783 | integer :: l, start_idx, end_idx | ||
| 784 | |||
| 785 | 507 | start_idx = 0 | |
| 786 | 507 | end_idx = 0 | |
| 787 |
2/2✓ Branch 0 taken 1513 times.
✓ Branch 1 taken 507 times.
|
2020 | do l = 1, this%num_layers |
| 788 | 507 | select type(current => this%model(l)%layer) | |
| 789 | class is(learnable_layer_type) | ||
| 790 | 1006 | start_idx = end_idx + 1 | |
| 791 | 2012 | end_idx = end_idx + current%get_num_params() | |
| 792 |
4/8✗ Branch 1 not taken.
✓ Branch 2 taken 1006 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1006 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1006 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1006 times.
|
1006 | call current%set_params(params(start_idx:end_idx)) |
| 793 | end select | ||
| 794 | end do | ||
| 795 | |||
| 796 | 507 | end subroutine set_params | |
| 797 | !!!############################################################################# | ||
| 798 | |||
| 799 | |||
| 800 | !!!############################################################################# | ||
| 801 | !!! get gradients | ||
| 802 | !!!############################################################################# | ||
| 803 | 509 | pure module function get_gradients(this) result(gradients) | |
| 804 | implicit none | ||
| 805 | class(network_type), intent(in) :: this | ||
| 806 | real(real12), allocatable, dimension(:) :: gradients | ||
| 807 | |||
| 808 | integer :: l, start_idx, end_idx | ||
| 809 | |||
| 810 | 509 | start_idx = 0 | |
| 811 | 509 | end_idx = 0 | |
| 812 |
13/24✓ Branch 1 taken 509 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 509 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 509 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 509 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 509 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 509 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 509 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 509 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 509 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 509 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 509 times.
✓ Branch 30 taken 16048 times.
✓ Branch 31 taken 509 times.
|
16557 | allocate(gradients(this%get_num_params()), source=0._real12) |
| 813 |
2/2✓ Branch 0 taken 1519 times.
✓ Branch 1 taken 509 times.
|
2028 | do l = 1, this%num_layers |
| 814 | 509 | select type(current => this%model(l)%layer) | |
| 815 | class is(learnable_layer_type) | ||
| 816 | 1010 | start_idx = end_idx + 1 | |
| 817 | 2020 | end_idx = end_idx + current%get_num_params() | |
| 818 |
6/10✗ Branch 1 not taken.
✓ Branch 2 taken 1010 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1010 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1010 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1010 times.
✓ Branch 14 taken 16048 times.
✓ Branch 15 taken 1010 times.
|
17058 | gradients(start_idx:end_idx) = current%get_gradients() |
| 819 | end select | ||
| 820 | end do | ||
| 821 | |||
| 822 | 509 | end function get_gradients | |
| 823 | !!!############################################################################# | ||
| 824 | |||
| 825 | |||
| 826 | !!!############################################################################# | ||
| 827 | !!! set gradients | ||
| 828 | !!!############################################################################# | ||
| 829 | 2 | module subroutine set_gradients(this, gradients) | |
| 830 | implicit none | ||
| 831 | class(network_type), intent(inout) :: this | ||
| 832 | real(real12), dimension(..), intent(in) :: gradients | ||
| 833 | |||
| 834 | integer :: l, start_idx, end_idx | ||
| 835 | |||
| 836 | 2 | start_idx = 0 | |
| 837 | 2 | end_idx = 0 | |
| 838 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | do l = 1, this%num_layers |
| 839 | 2 | select type(current => this%model(l)%layer) | |
| 840 | class is(learnable_layer_type) | ||
| 841 | 4 | start_idx = end_idx + 1 | |
| 842 | 8 | end_idx = end_idx + current%get_num_params() | |
| 843 | 4 | select rank(gradients) | |
| 844 | rank(0) | ||
| 845 | 2 | call current%set_gradients(gradients) | |
| 846 | rank(1) | ||
| 847 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | call current%set_gradients(gradients(start_idx:end_idx)) |
| 848 | end select | ||
| 849 | end select | ||
| 850 | end do | ||
| 851 | |||
| 852 | 2 | end subroutine set_gradients | |
| 853 | !!!############################################################################# | ||
| 854 | |||
| 855 | |||
| 856 | !!!############################################################################# | ||
| 857 | !!! reset gradients | ||
| 858 | !!!############################################################################# | ||
| 859 | 507 | module subroutine reset_gradients(this) | |
| 860 | implicit none | ||
| 861 | class(network_type), intent(inout) :: this | ||
| 862 | |||
| 863 | integer :: l | ||
| 864 | |||
| 865 |
2/2✓ Branch 0 taken 1513 times.
✓ Branch 1 taken 507 times.
|
2020 | do l = 1, this%num_layers |
| 866 | 507 | select type(current => this%model(l)%layer) | |
| 867 | class is(learnable_layer_type) | ||
| 868 | 1006 | call current%set_gradients(0._real12) | |
| 869 | end select | ||
| 870 | end do | ||
| 871 | |||
| 872 | 507 | end subroutine reset_gradients | |
| 873 | !!!############################################################################# | ||
| 874 | |||
| 875 | |||
| 876 | !!!##########################################################################!!! | ||
| 877 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
| 878 | !!!##########################################################################!!! | ||
| 879 | |||
| 880 | |||
| 881 | !!!############################################################################# | ||
| 882 | !!! forward pass | ||
| 883 | !!!############################################################################# | ||
| 884 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 521 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
521 | pure module subroutine forward_1d(this, input, addit_input, layer) |
| 885 | implicit none | ||
| 886 | class(network_type), intent(inout) :: this | ||
| 887 | real(real12), dimension(..), intent(in) :: input | ||
| 888 | |||
| 889 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 890 | integer, optional, intent(in) :: layer | ||
| 891 | |||
| 892 | integer :: i | ||
| 893 | |||
| 894 | |||
| 895 | !!-------------------------------------------------------------------------- | ||
| 896 | !! initialise optional arguments | ||
| 897 | !!-------------------------------------------------------------------------- | ||
| 898 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 521 times.
|
521 | if(present(layer).and.present(addit_input))then |
| 899 | ✗ | select type(previous => this%model(layer-1)%layer) | |
| 900 | type is(flatten1d_layer_type) | ||
| 901 | ✗ | previous%output(size(previous%di(:,:,1)) - & | |
| 902 | ✗ | size(addit_input,1)+1:,:) = addit_input | |
| 903 | type is(flatten2d_layer_type) | ||
| 904 | ✗ | previous%output(size(previous%di(:,:,:,1)) - & | |
| 905 | ✗ | size(addit_input,1)+1:,:) = addit_input | |
| 906 | type is(flatten3d_layer_type) | ||
| 907 | ✗ | previous%output(size(previous%di(:,:,:,:,1)) - & | |
| 908 | ✗ | size(addit_input,1)+1:,:) = addit_input | |
| 909 | type is(flatten4d_layer_type) | ||
| 910 | ✗ | previous%output(size(previous%di(:,:,:,:,:,1)) - & | |
| 911 | ✗ | size(addit_input,1)+1:,:) = addit_input | |
| 912 | end select | ||
| 913 | end if | ||
| 914 | |||
| 915 | |||
| 916 | !! Forward pass (first layer) | ||
| 917 | !!-------------------------------------------------------------------------- | ||
| 918 | 1042 | select type(current => this%model(1)%layer) | |
| 919 | class is(input_layer_type) | ||
| 920 | 521 | call current%set(input) | |
| 921 | end select | ||
| 922 | |||
| 923 | !! Forward pass | ||
| 924 | !!-------------------------------------------------------------------------- | ||
| 925 |
2/2✓ Branch 0 taken 1026 times.
✓ Branch 1 taken 521 times.
|
1547 | do i=2,this%num_layers,1 |
| 926 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1026 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1026 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1026 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1026 times.
|
1547 | call this%model(i)%forward(this%model(i-1)) |
| 927 | end do | ||
| 928 | |||
| 929 | 521 | end subroutine forward_1d | |
| 930 | !!!############################################################################# | ||
| 931 | |||
| 932 | |||
| 933 | !!!############################################################################# | ||
| 934 | !!! backward pass | ||
| 935 | !!!############################################################################# | ||
| 936 |
1/2✓ Branch 0 taken 510 times.
✗ Branch 1 not taken.
|
510 | pure module subroutine backward_1d(this, output) |
| 937 | implicit none | ||
| 938 | class(network_type), intent(inout) :: this | ||
| 939 | real(real12), dimension(:,:), intent(in) :: output | ||
| 940 | |||
| 941 | integer :: i | ||
| 942 | 510 | real(real12), allocatable, dimension(:,:) :: predicted | |
| 943 | |||
| 944 | |||
| 945 | !! Backward pass (final layer) | ||
| 946 | !!------------------------------------------------------------------- | ||
| 947 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 510 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 510 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 510 times.
✓ Branch 8 taken 510 times.
✗ Branch 9 not taken.
|
510 | call this%model(this%num_layers)%layer%get_output(predicted) |
| 948 | 1020 | call this%model(this%num_layers)%backward( & | |
| 949 | 1020 | this%model(this%num_layers-1), & | |
| 950 |
23/46✗ Branch 0 not taken.
✓ Branch 1 taken 510 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 510 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 510 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 510 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 510 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 510 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 510 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 510 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 510 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 510 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 510 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 510 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 510 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 510 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 510 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 510 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 510 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 510 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 510 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 510 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 510 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 510 times.
✓ Branch 54 taken 510 times.
✗ Branch 55 not taken.
|
510 | this%get_loss_deriv(predicted, output)) |
| 951 | |||
| 952 | |||
| 953 | !! Backward pass | ||
| 954 | !!------------------------------------------------------------------- | ||
| 955 |
2/2✓ Branch 0 taken 502 times.
✓ Branch 1 taken 510 times.
|
1012 | do i=this%num_layers-1,2,-1 |
| 956 | 510 | select type(next => this%model(i+1)%layer) | |
| 957 | type is(batchnorm1d_layer_type) | ||
| 958 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 959 | type is(batchnorm2d_layer_type) | ||
| 960 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 961 | type is(batchnorm3d_layer_type) | ||
| 962 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 963 | |||
| 964 | type is(conv1d_layer_type) | ||
| 965 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 966 | type is(conv2d_layer_type) | ||
| 967 |
20/40✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
|
1 | call this%model(i)%backward(this%model(i-1),next%di) |
| 968 | type is(conv3d_layer_type) | ||
| 969 |
24/48✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 1 times.
|
1 | call this%model(i)%backward(this%model(i-1),next%di) |
| 970 | |||
| 971 | type is(dropout_layer_type) | ||
| 972 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 973 | type is(dropblock2d_layer_type) | ||
| 974 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 975 | type is(dropblock3d_layer_type) | ||
| 976 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 977 | |||
| 978 | type is(avgpool1d_layer_type) | ||
| 979 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 980 | type is(avgpool2d_layer_type) | ||
| 981 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 982 | type is(avgpool3d_layer_type) | ||
| 983 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 984 | type is(maxpool1d_layer_type) | ||
| 985 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 986 | type is(maxpool2d_layer_type) | ||
| 987 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 988 | type is(maxpool3d_layer_type) | ||
| 989 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 990 | |||
| 991 | type is(flatten1d_layer_type) | ||
| 992 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 993 | type is(flatten2d_layer_type) | ||
| 994 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 995 | type is(flatten3d_layer_type) | ||
| 996 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 997 | type is(flatten4d_layer_type) | ||
| 998 | ✗ | call this%model(i)%backward(this%model(i-1),next%di) | |
| 999 | |||
| 1000 | type is(full_layer_type) | ||
| 1001 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 500 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 500 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 500 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 500 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 500 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 500 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 500 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 500 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 500 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 500 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 500 times.
|
500 | call this%model(i)%backward(this%model(i-1),next%di) |
| 1002 | end select | ||
| 1003 | end do | ||
| 1004 | |||
| 1005 |
1/2✓ Branch 0 taken 510 times.
✗ Branch 1 not taken.
|
510 | end subroutine backward_1d |
| 1006 | !!!############################################################################# | ||
| 1007 | |||
| 1008 | |||
| 1009 | !!!############################################################################# | ||
| 1010 | !!! update weights and biases | ||
| 1011 | !!!############################################################################# | ||
| 1012 | 507 | module subroutine update(this) | |
| 1013 | implicit none | ||
| 1014 | class(network_type), intent(inout) :: this | ||
| 1015 | 507 | real(real12), allocatable, dimension(:) :: params, gradients | |
| 1016 | |||
| 1017 | integer :: i | ||
| 1018 | |||
| 1019 | |||
| 1020 | !!------------------------------------------------------------------- | ||
| 1021 | !! Update layers of learnable layer types | ||
| 1022 | !!------------------------------------------------------------------- | ||
| 1023 |
4/8✗ Branch 1 not taken.
✓ Branch 2 taken 507 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 507 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 15984 times.
✓ Branch 8 taken 507 times.
|
16491 | params = this%get_params() |
| 1024 |
4/8✗ Branch 1 not taken.
✓ Branch 2 taken 507 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 507 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 15984 times.
✓ Branch 8 taken 507 times.
|
16491 | gradients = this%get_gradients() |
| 1025 | 507 | call this%optimiser%minimise(params, gradients) | |
| 1026 | 507 | call this%set_params(params) | |
| 1027 | 507 | call this%reset_gradients() | |
| 1028 | |||
| 1029 | !! Increment optimiser iteration counter | ||
| 1030 | !!------------------------------------------------------------------- | ||
| 1031 | 507 | this%optimiser%iter = this%optimiser%iter + 1 | |
| 1032 | |||
| 1033 |
2/4✓ Branch 0 taken 507 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 507 times.
✗ Branch 3 not taken.
|
507 | end subroutine update |
| 1034 | !!!############################################################################# | ||
| 1035 | |||
| 1036 | |||
| 1037 | !!!##########################################################################!!! | ||
| 1038 | !!! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !!! | ||
| 1039 | !!!##########################################################################!!! | ||
| 1040 | |||
| 1041 | |||
| 1042 | !!!############################################################################# | ||
| 1043 | !!! training loop | ||
| 1044 | !!! ... loops over num_epoch number of epochs | ||
| 1045 | !!! ... i.e. it trains on the same datapoints num_epoch times | ||
| 1046 | !!!############################################################################# | ||
| 1047 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | module subroutine train(this, input, output, num_epochs, batch_size, & |
| 1048 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1 | addit_input, addit_layer, & |
| 1049 | plateau_threshold, shuffle_batches, batch_print_step, verbose) | ||
| 1050 | use infile_tools, only: stop_check | ||
| 1051 | implicit none | ||
| 1052 | class(network_type), intent(inout) :: this | ||
| 1053 | real(real12), dimension(..), intent(in) :: input | ||
| 1054 | class(*), dimension(:,:), intent(in) :: output | ||
| 1055 | integer, intent(in) :: num_epochs | ||
| 1056 | integer, optional, intent(in) :: batch_size !! deprecated | ||
| 1057 | |||
| 1058 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 1059 | integer, optional, intent(in) :: addit_layer | ||
| 1060 | |||
| 1061 | real(real12), optional, intent(in) :: plateau_threshold | ||
| 1062 | logical, optional, intent(in) :: shuffle_batches | ||
| 1063 | integer, optional, intent(in) :: batch_print_step | ||
| 1064 | integer, optional, intent(in) :: verbose | ||
| 1065 | |||
| 1066 | !! training and testing monitoring | ||
| 1067 | real(real12) :: batch_loss, batch_accuracy, avg_loss, avg_accuracy | ||
| 1068 | 1 | real(real12), allocatable, dimension(:,:) :: y_true | |
| 1069 | |||
| 1070 | !! learning parameters | ||
| 1071 | integer :: l, num_samples | ||
| 1072 | integer :: num_batches | ||
| 1073 | integer :: converged | ||
| 1074 | integer :: history_length | ||
| 1075 | integer :: verbose_ = 0 | ||
| 1076 | integer :: batch_print_step_ = 20 | ||
| 1077 | real(real12) :: plateau_threshold_ = 1.E-2_real12 | ||
| 1078 | logical :: shuffle_batches_ = .true. | ||
| 1079 | |||
| 1080 | !! training loop variables | ||
| 1081 | integer :: epoch, batch, start_index, end_index | ||
| 1082 | 1 | integer, allocatable, dimension(:) :: batch_order | |
| 1083 | |||
| 1084 | integer :: i, time, time_old, clock_rate | ||
| 1085 | |||
| 1086 | #ifdef _OPENMP | ||
| 1087 | type(network_type) :: this_copy | ||
| 1088 | real(real12), allocatable, dimension(:,:) :: input_slice, addit_input_slice | ||
| 1089 | #endif | ||
| 1090 | integer :: timer_start = 0, timer_stop = 0, timer_sum = 0, timer_tot = 0 | ||
| 1091 | |||
| 1092 | |||
| 1093 | !!!----------------------------------------------------------------------------- | ||
| 1094 | !!! initialise optional arguments | ||
| 1095 | !!!----------------------------------------------------------------------------- | ||
| 1096 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(plateau_threshold)) plateau_threshold_ = plateau_threshold |
| 1097 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(shuffle_batches)) shuffle_batches_ = shuffle_batches |
| 1098 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(batch_print_step)) batch_print_step_ = batch_print_step |
| 1099 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(verbose)) verbose_ = verbose |
| 1100 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | if(present(batch_size)) this%batch_size = batch_size |
| 1101 | |||
| 1102 | |||
| 1103 | !!!----------------------------------------------------------------------------- | ||
| 1104 | !!! initialise monitoring variables | ||
| 1105 | !!!----------------------------------------------------------------------------- | ||
| 1106 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | history_length = max(ceiling(500._real12/this%batch_size),1) |
| 1107 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do i=1,size(this%metrics,dim=1) |
| 1108 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
|
2 | if(allocated(this%metrics(i)%history)) & |
| 1109 | ✗ | deallocate(this%metrics(i)%history) | |
| 1110 |
9/18✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✓ Branch 6 taken 2 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 2 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 2 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 2 times.
|
2 | allocate(this%metrics(i)%history(history_length)) |
| 1111 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✓ Branch 6 taken 1000 times.
✓ Branch 7 taken 2 times.
|
1003 | this%metrics(i)%history = -huge(1._real12) |
| 1112 | end do | ||
| 1113 | |||
| 1114 | |||
| 1115 | !!!----------------------------------------------------------------------------- | ||
| 1116 | !!! allocate predicted and true label sets | ||
| 1117 | !!!----------------------------------------------------------------------------- | ||
| 1118 |
21/38✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✓ Branch 47 taken 2 times.
✓ Branch 48 taken 1 times.
|
4 | allocate(y_true(this%num_outputs,this%batch_size), source = 0._real12) |
| 1119 | |||
| 1120 | |||
| 1121 | !!!----------------------------------------------------------------------------- | ||
| 1122 | !!! if parallel, initialise slices | ||
| 1123 | !!!----------------------------------------------------------------------------- | ||
| 1124 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
|
1 | num_batches = size(output,dim=2) / this%batch_size |
| 1125 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | allocate(batch_order(num_batches)) |
| 1126 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | do batch = 1, num_batches |
| 1127 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
2 | batch_order(batch) = batch |
| 1128 | end do | ||
| 1129 | |||
| 1130 | |||
| 1131 | !!!----------------------------------------------------------------------------- | ||
| 1132 | !!! get number of samples | ||
| 1133 | !!!----------------------------------------------------------------------------- | ||
| 1134 | select rank(input) | ||
| 1135 | rank(1) | ||
| 1136 | ✗ | write(*,*) "Cannot check number of samples in rank 1 input" | |
| 1137 | rank default | ||
| 1138 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | num_samples = size(input,rank(input)) |
| 1139 |
7/14✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
|
1 | if(size(output,2).ne.num_samples)then |
| 1140 | ✗ | write(0,*) "ERROR: number of samples in input and output do not match" | |
| 1141 | ✗ | stop "Exiting..." | |
| 1142 |
7/14✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
|
1 | elseif(size(output,1).ne.this%num_outputs)then |
| 1143 | ✗ | write(0,*) "ERROR: number of outputs in output does not match network" | |
| 1144 | ✗ | stop "Exiting..." | |
| 1145 | end if | ||
| 1146 | end select | ||
| 1147 | |||
| 1148 | |||
| 1149 | !!!----------------------------------------------------------------------------- | ||
| 1150 | !!! set/reset batch size for training | ||
| 1151 | !!!----------------------------------------------------------------------------- | ||
| 1152 | 1 | call this%set_batch_size(this%batch_size) | |
| 1153 | |||
| 1154 | |||
| 1155 | |||
| 1156 | !!!----------------------------------------------------------------------------- | ||
| 1157 | !!! turn off inference booleans | ||
| 1158 | !!!----------------------------------------------------------------------------- | ||
| 1159 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
|
4 | do l=1,this%num_layers |
| 1160 | 1 | select type(current => this%model(l)%layer) | |
| 1161 | class is(drop_layer_type) | ||
| 1162 | ✗ | current%inference = .false. | |
| 1163 | end select | ||
| 1164 | end do | ||
| 1165 | |||
| 1166 | |||
| 1167 | !!!----------------------------------------------------------------------------- | ||
| 1168 | !!! query system clock | ||
| 1169 | !!!----------------------------------------------------------------------------- | ||
| 1170 | 1 | call system_clock(time, count_rate = clock_rate) | |
| 1171 | |||
| 1172 | |||
| 1173 |
1/2✓ Branch 0 taken 500 times.
✗ Branch 1 not taken.
|
500 | epoch_loop: do epoch = 1, num_epochs |
| 1174 | !!----------------------------------------------------------------------- | ||
| 1175 | !! shuffle batch order at the start of each epoch | ||
| 1176 | !!----------------------------------------------------------------------- | ||
| 1177 |
1/2✓ Branch 0 taken 500 times.
✗ Branch 1 not taken.
|
500 | if(shuffle_batches_)then |
| 1178 | 500 | call shuffle(batch_order) | |
| 1179 | end if | ||
| 1180 | |||
| 1181 | 500 | avg_loss = 0._real12 | |
| 1182 | 500 | avg_accuracy = 0._real12 | |
| 1183 | |||
| 1184 | !!----------------------------------------------------------------------- | ||
| 1185 | !! batch loop | ||
| 1186 | !! ... split data up into minibatches for training | ||
| 1187 | !!----------------------------------------------------------------------- | ||
| 1188 |
2/2✓ Branch 0 taken 500 times.
✓ Branch 1 taken 499 times.
|
999 | batch_loop: do batch = 1, num_batches |
| 1189 | |||
| 1190 | |||
| 1191 | !! set batch start and end index | ||
| 1192 | !!-------------------------------------------------------------------- | ||
| 1193 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 500 times.
|
500 | start_index = (batch_order(batch) - 1) * this%batch_size + 1 |
| 1194 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 500 times.
|
500 | end_index = batch_order(batch) * this%batch_size |
| 1195 | |||
| 1196 | |||
| 1197 | !! reinitialise variables | ||
| 1198 | !!-------------------------------------------------------------------- | ||
| 1199 | ✗ | select type(output) | |
| 1200 | type is(integer) | ||
| 1201 | ✗ | y_true(:,:) = real(output(:,start_index:end_index:1),real12) | |
| 1202 | type is(real) | ||
| 1203 |
27/50✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 500 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 500 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 500 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 500 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 500 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 500 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 500 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 500 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 500 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 500 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 500 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 500 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 500 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 500 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 500 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 500 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 500 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 500 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 500 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 500 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 500 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 500 times.
✓ Branch 63 taken 500 times.
✓ Branch 64 taken 500 times.
✓ Branch 65 taken 1000 times.
✓ Branch 66 taken 500 times.
|
2000 | y_true(:,:) = output(:,start_index:end_index:1) |
| 1204 | end select | ||
| 1205 | |||
| 1206 | |||
| 1207 | !! Forward pass | ||
| 1208 | !!----------------------------------------------------------------- | ||
| 1209 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
|
500 | if(present(addit_input).and.present(addit_layer))then |
| 1210 | call this%forward(get_sample(input,start_index,end_index),& | ||
| 1211 | ✗ | addit_input(:,start_index:end_index),addit_layer) | |
| 1212 | else | ||
| 1213 | 500 | call this%forward(get_sample(input,start_index,end_index)) | |
| 1214 | end if | ||
| 1215 | |||
| 1216 | |||
| 1217 | !! Backward pass and store predicted output | ||
| 1218 | !!----------------------------------------------------------------- | ||
| 1219 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 500 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 500 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 500 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 500 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 500 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 500 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 500 times.
|
500 | call this%backward(y_true(:,:)) |
| 1220 | 1000 | select type(current => this%model(this%num_layers)%layer) | |
| 1221 | type is(full_layer_type) | ||
| 1222 | !! compute loss and accuracy (for monitoring) | ||
| 1223 | !!------------------------------------------------------------------- | ||
| 1224 | batch_loss = sum( & | ||
| 1225 | this%get_loss( & | ||
| 1226 | 4000 | current%output(:,1:this%batch_size), & | |
| 1227 |
32/60✗ Branch 1 not taken.
✓ Branch 2 taken 500 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 500 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 500 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 500 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 500 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 500 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 500 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 500 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 500 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 500 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 500 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 500 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 500 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 500 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 500 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 500 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 500 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 500 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 500 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 500 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 500 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 500 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 500 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 500 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 500 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 500 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 500 times.
✓ Branch 70 taken 500 times.
✗ Branch 71 not taken.
✓ Branch 73 taken 500 times.
✓ Branch 74 taken 500 times.
✓ Branch 75 taken 1000 times.
✓ Branch 76 taken 500 times.
|
2000 | y_true(:,1:this%batch_size))) |
| 1228 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
|
500 | select type(output) |
| 1229 | type is(integer) | ||
| 1230 | batch_accuracy = sum(categorical_score( & | ||
| 1231 | ✗ | current%output(:,1:this%batch_size), & | |
| 1232 | ✗ | output(:,start_index:end_index))) | |
| 1233 | type is(real) | ||
| 1234 | batch_accuracy = sum(mae_score( & | ||
| 1235 | 4000 | current%output(:,1:this%batch_size), & | |
| 1236 |
30/58✗ Branch 0 not taken.
✓ Branch 1 taken 500 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 500 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 500 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 500 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 500 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 500 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 500 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 500 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 500 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 500 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 500 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 500 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 500 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 500 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 500 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 500 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 500 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 500 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 500 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 500 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 500 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 500 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 500 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 500 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 500 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 500 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 500 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 500 times.
✓ Branch 72 taken 500 times.
✓ Branch 73 taken 500 times.
|
1000 | output(:,start_index:end_index))) |
| 1237 | end select | ||
| 1238 | class default | ||
| 1239 | ✗ | stop "ERROR: final layer not of type full_layer_type" | |
| 1240 | end select | ||
| 1241 | |||
| 1242 | |||
| 1243 | |||
| 1244 | !! Average metric over batch size and store | ||
| 1245 | !! Check metric convergence | ||
| 1246 | !!-------------------------------------------------------------------- | ||
| 1247 | 500 | avg_loss = avg_loss + batch_loss | |
| 1248 | 500 | avg_accuracy = avg_accuracy + batch_accuracy | |
| 1249 | 500 | this%metrics(1)%val = batch_loss / this%batch_size | |
| 1250 | 500 | this%metrics(2)%val = batch_accuracy / this%batch_size | |
| 1251 |
2/2✓ Branch 0 taken 999 times.
✓ Branch 1 taken 499 times.
|
1498 | do i = 1, size(this%metrics,dim=1) |
| 1252 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 999 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 999 times.
|
999 | call this%metrics(i)%check(plateau_threshold_, converged) |
| 1253 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 998 times.
|
1498 | if(converged.ne.0)then |
| 1254 | 1 | exit epoch_loop | |
| 1255 | end if | ||
| 1256 | end do | ||
| 1257 | |||
| 1258 | |||
| 1259 | !! update weights and biases using optimization algorithm | ||
| 1260 | !! ... (gradient descent) | ||
| 1261 | !!-------------------------------------------------------------------- | ||
| 1262 | !! STORE ADAM VALUES IN OPTIMISER | ||
| 1263 | 499 | call this%update() | |
| 1264 | |||
| 1265 | |||
| 1266 | !! print batch results | ||
| 1267 | !!-------------------------------------------------------------------- | ||
| 1268 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 499 times.
|
499 | if(abs(verbose_).gt.0.and.& |
| 1269 | (batch.eq.1.or.mod(batch,batch_print_step_).eq.0.E0))then | ||
| 1270 | write(6,'("epoch=",I0,", batch=",I0,& | ||
| 1271 | &", learning_rate=",F0.3,", loss=",F0.3,", accuracy=",F0.3)')& | ||
| 1272 | ✗ | epoch, batch, & | |
| 1273 | ✗ | this%optimiser%learning_rate, & | |
| 1274 | ✗ | avg_loss/(batch*this%batch_size), & | |
| 1275 | ✗ | avg_accuracy/(batch*this%batch_size) | |
| 1276 | end if | ||
| 1277 | |||
| 1278 | |||
| 1279 | !!! TESTING | ||
| 1280 | !#ifdef _OPENMP | ||
| 1281 | ! call system_clock(timer_start) | ||
| 1282 | ! call system_clock(timer_stop) | ||
| 1283 | ! timer_sum = timer_sum + timer_stop - timer_start | ||
| 1284 | ! timer_tot = timer_tot + timer_sum / omp_get_max_threads() | ||
| 1285 | !#else | ||
| 1286 | ! timer_tot = timer_tot + timer_sum | ||
| 1287 | !#endif | ||
| 1288 | ! timer_sum = 0 | ||
| 1289 | ! if(batch.gt.200)then | ||
| 1290 | ! time_old = time | ||
| 1291 | ! call system_clock(time) | ||
| 1292 | ! write(*,'("time check: ",F8.3," seconds")') real(time-time_old)/clock_rate | ||
| 1293 | ! !write(*,'("update timer: ",F8.3," seconds")') real(timer_tot)/clock_rate | ||
| 1294 | ! exit epoch_loop | ||
| 1295 | ! stop "THIS IS FOR TESTING PURPOSES" | ||
| 1296 | ! end if | ||
| 1297 | !!! | ||
| 1298 | |||
| 1299 | |||
| 1300 | !! time check | ||
| 1301 | !!-------------------------------------------------------------------- | ||
| 1302 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 499 times.
|
499 | if(verbose_.eq.-2)then |
| 1303 | ✗ | time_old = time | |
| 1304 | ✗ | call system_clock(time) | |
| 1305 | write(*,'("time check: ",F5.3," seconds")') & | ||
| 1306 | ✗ | real(time-time_old)/clock_rate | |
| 1307 | ✗ | time_old = time | |
| 1308 | end if | ||
| 1309 | |||
| 1310 | |||
| 1311 | !! check for user-name stop file | ||
| 1312 | !!-------------------------------------------------------------------- | ||
| 1313 |
1/2✗ Branch 1 not taken.
✓ Branch 2 taken 499 times.
|
998 | if(stop_check())then |
| 1314 | ✗ | write(0,*) "STOPCAR ENCOUNTERED" | |
| 1315 | ✗ | write(0,*) "Exiting training loop..." | |
| 1316 | ✗ | exit epoch_loop | |
| 1317 | end if | ||
| 1318 | |||
| 1319 | end do batch_loop | ||
| 1320 | |||
| 1321 | |||
| 1322 | !! print epoch summary results | ||
| 1323 | !!----------------------------------------------------------------------- | ||
| 1324 |
1/2✓ Branch 0 taken 499 times.
✗ Branch 1 not taken.
|
499 | if(verbose_.eq.0)then |
| 1325 | write(6,'("epoch=",I0,", batch=",I0,& | ||
| 1326 | &", learning_rate=",F0.3,", val_loss=",F0.3,& | ||
| 1327 | &", val_accuracy=",F0.3)') & | ||
| 1328 | 499 | epoch, batch, & | |
| 1329 | 499 | this%optimiser%learning_rate, & | |
| 1330 | 998 | this%metrics(1)%val, this%metrics(2)%val | |
| 1331 | end if | ||
| 1332 | |||
| 1333 | |||
| 1334 | end do epoch_loop | ||
| 1335 | |||
| 1336 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | end subroutine train |
| 1337 | !!!############################################################################# | ||
| 1338 | |||
| 1339 | |||
| 1340 | !!!############################################################################# | ||
| 1341 | !!! testing loop | ||
| 1342 | !!!############################################################################# | ||
| 1343 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | module subroutine test(this, input, output, & |
| 1344 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
1 | addit_input, addit_layer, & |
| 1345 | verbose) | ||
| 1346 | implicit none | ||
| 1347 | class(network_type), intent(inout) :: this | ||
| 1348 | real(real12), dimension(..), intent(in) :: input | ||
| 1349 | class(*), dimension(:,:), intent(in) :: output | ||
| 1350 | |||
| 1351 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 1352 | integer, optional, intent(in) :: addit_layer | ||
| 1353 | |||
| 1354 | integer, optional, intent(in) :: verbose | ||
| 1355 | |||
| 1356 | integer :: l, sample, num_samples | ||
| 1357 | integer :: verbose_, unit | ||
| 1358 | real(real12) :: acc_val, loss_val | ||
| 1359 | 1 | real(real12), allocatable, dimension(:) :: accuracy_list | |
| 1360 | 1 | real(real12), allocatable, dimension(:,:) :: predicted, y_true | |
| 1361 | |||
| 1362 | |||
| 1363 | !!!----------------------------------------------------------------------------- | ||
| 1364 | !!! initialise optional arguments | ||
| 1365 | !!!----------------------------------------------------------------------------- | ||
| 1366 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(verbose))then |
| 1367 | ✗ | verbose_ = verbose | |
| 1368 | else | ||
| 1369 | 1 | verbose_ = 0 | |
| 1370 | end if | ||
| 1371 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
|
1 | num_samples = size(output, dim=2) |
| 1372 |
15/30✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✓ Branch 21 taken 1 times.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
|
1 | allocate(predicted(size(output,1), num_samples)) |
| 1373 | |||
| 1374 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | this%metrics%val = 0._real12 |
| 1375 | 1 | acc_val = 0._real12 | |
| 1376 | 1 | loss_val = 0._real12 | |
| 1377 |
7/14✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | allocate(accuracy_list(num_samples)) |
| 1378 | |||
| 1379 | |||
| 1380 | ✗ | select type(output) | |
| 1381 | type is(integer) | ||
| 1382 | ✗ | y_true = real(output(:,:),real12) | |
| 1383 | type is(real) | ||
| 1384 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 1 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1 times.
✓ Branch 39 taken 1 times.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
|
4 | y_true = output(:,:) |
| 1385 | end select | ||
| 1386 | |||
| 1387 | |||
| 1388 | !!!----------------------------------------------------------------------------- | ||
| 1389 | !!! reset batch size for testing | ||
| 1390 | !!!----------------------------------------------------------------------------- | ||
| 1391 | 1 | call this%set_batch_size(1) | |
| 1392 | |||
| 1393 | |||
| 1394 | !!!----------------------------------------------------------------------------- | ||
| 1395 | !!! turn on inference booleans | ||
| 1396 | !!!----------------------------------------------------------------------------- | ||
| 1397 |
2/2✓ Branch 0 taken 3 times.
✓ Branch 1 taken 1 times.
|
4 | do l=1,this%num_layers |
| 1398 | 1 | select type(current => this%model(l)%layer) | |
| 1399 | class is(drop_layer_type) | ||
| 1400 | ✗ | current%inference = .true. | |
| 1401 | end select | ||
| 1402 | end do | ||
| 1403 | |||
| 1404 | |||
| 1405 | !!!----------------------------------------------------------------------------- | ||
| 1406 | !!! testing loop | ||
| 1407 | !!!----------------------------------------------------------------------------- | ||
| 1408 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | test_loop: do sample = 1, num_samples |
| 1409 | |||
| 1410 | !! Forward pass | ||
| 1411 | !!----------------------------------------------------------------------- | ||
| 1412 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(present(addit_input).and.present(addit_layer))then |
| 1413 | call this%forward(get_sample(input,sample,sample),& | ||
| 1414 | ✗ | addit_input(:,sample:sample),addit_layer) | |
| 1415 | else | ||
| 1416 | 1 | call this%forward(get_sample(input,sample,sample)) | |
| 1417 | end if | ||
| 1418 | |||
| 1419 | |||
| 1420 | !! compute loss and accuracy (for monitoring) | ||
| 1421 | !!----------------------------------------------------------------------- | ||
| 1422 | 1 | select type(current => this%model(this%num_layers)%layer) | |
| 1423 | type is(full_layer_type) | ||
| 1424 | loss_val = sum(this%get_loss( & | ||
| 1425 | predicted = current%output, & | ||
| 1426 | !!!! JUST REPLACE y_true(:,sample) WITH output(:,sample) !!!! | ||
| 1427 | !!!! THERE IS NO REASON TO USE y_true, as it is just a copy !!!! | ||
| 1428 | !!!! get_loss should handle both integers and reals !!!! | ||
| 1429 | !!!! it does not. Instead just wrap real(output(:,sample),real12) !!!! | ||
| 1430 |
24/44✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 1 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✓ Branch 49 taken 1 times.
✓ Branch 50 taken 1 times.
✓ Branch 51 taken 2 times.
✓ Branch 52 taken 1 times.
|
4 | expected = y_true(:,sample:sample))) |
| 1431 | ✗ | select type(output) | |
| 1432 | type is(integer) | ||
| 1433 | ✗ | acc_val = sum(categorical_score(current%output,output(:,sample:sample))) | |
| 1434 | type is(real) | ||
| 1435 |
22/42✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 1 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✓ Branch 48 taken 1 times.
✓ Branch 49 taken 1 times.
|
2 | acc_val = sum(mae_score(current%output,output(:,sample:sample))) |
| 1436 | end select | ||
| 1437 | 1 | this%metrics(2)%val = this%metrics(2)%val + acc_val | |
| 1438 | 1 | this%metrics(1)%val = this%metrics(1)%val + loss_val | |
| 1439 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | accuracy_list(sample) = acc_val |
| 1440 |
16/30✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 1 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 1 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 1 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 1 times.
✓ Branch 41 taken 2 times.
✓ Branch 42 taken 1 times.
|
4 | predicted(:,sample) = current%output(:,1) |
| 1441 | end select | ||
| 1442 | |||
| 1443 | end do test_loop | ||
| 1444 | |||
| 1445 | |||
| 1446 | !! print testing results | ||
| 1447 | !!-------------------------------------------------------------------- | ||
| 1448 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
|
1 | if(abs(verbose_).gt.1)then |
| 1449 | ✗ | open(file="test_output.out",newunit=unit) | |
| 1450 | ✗ | select type(final_layer => this%model(this%num_layers)%layer) | |
| 1451 | type is(full_layer_type) | ||
| 1452 | ✗ | test_loop: do concurrent(sample = 1:num_samples) | |
| 1453 | ✗ | select type(output) | |
| 1454 | type is(integer) | ||
| 1455 | write(unit,'(I4," Expected=",I3,", Got=",I3,", Accuracy=",F0.3)') & | ||
| 1456 | ✗ | sample, & | |
| 1457 | ✗ | maxloc(output(:,sample)), maxloc(predicted(:,sample),dim=1)-1, & | |
| 1458 | ✗ | accuracy_list(sample) | |
| 1459 | type is(real) | ||
| 1460 | write(unit,'(I4," Expected=",F0.3,", Got=",F0.3,", Accuracy=",F0.3)') & | ||
| 1461 | ✗ | sample, & | |
| 1462 | ✗ | output(:,sample), predicted(:,sample), & | |
| 1463 | ✗ | accuracy_list(sample) | |
| 1464 | end select | ||
| 1465 | end do test_loop | ||
| 1466 | end select | ||
| 1467 | ✗ | close(unit) | |
| 1468 | end if | ||
| 1469 | |||
| 1470 | |||
| 1471 | !! normalise metrics by number of samples | ||
| 1472 | !!-------------------------------------------------------------------- | ||
| 1473 | 1 | this%accuracy = this%metrics(2)%val/real(num_samples) | |
| 1474 | 1 | this%loss = this%metrics(1)%val/real(num_samples) | |
| 1475 | |||
| 1476 |
3/6✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
|
1 | end subroutine test |
| 1477 | !!!############################################################################# | ||
| 1478 | |||
| 1479 | |||
| 1480 | !!!############################################################################# | ||
| 1481 | !!! predict outputs from input data using trained network | ||
| 1482 | !!!############################################################################# | ||
| 1483 | 8 | module function predict_1d(this, input, & | |
| 1484 |
1/6✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
|
8 | addit_input, addit_layer, & |
| 1485 | verbose) result(output) | ||
| 1486 | implicit none | ||
| 1487 | class(network_type), intent(inout) :: this | ||
| 1488 | real(real12), dimension(..), intent(in) :: input | ||
| 1489 | |||
| 1490 | real(real12), dimension(:,:), optional, intent(in) :: addit_input | ||
| 1491 | integer, optional, intent(in) :: addit_layer | ||
| 1492 | |||
| 1493 | integer, optional, intent(in) :: verbose | ||
| 1494 | |||
| 1495 | real(real12), dimension(:,:), allocatable :: output | ||
| 1496 | |||
| 1497 | integer :: verbose_, batch_size | ||
| 1498 | |||
| 1499 | |||
| 1500 | !!!----------------------------------------------------------------------------- | ||
| 1501 | !!! initialise optional arguments | ||
| 1502 | !!!----------------------------------------------------------------------------- | ||
| 1503 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if(present(verbose))then |
| 1504 | ✗ | verbose_ = verbose | |
| 1505 | else | ||
| 1506 | 8 | verbose_ = 0 | |
| 1507 | end if | ||
| 1508 | |||
| 1509 | select rank(input) | ||
| 1510 | rank(2) | ||
| 1511 | 8 | batch_size = size(input,dim=2) | |
| 1512 | rank(3) | ||
| 1513 | ✗ | batch_size = size(input,dim=3) | |
| 1514 | rank(4) | ||
| 1515 | ✗ | batch_size = size(input,dim=4) | |
| 1516 | rank(5) | ||
| 1517 | ✗ | batch_size = size(input,dim=5) | |
| 1518 | rank(6) | ||
| 1519 | ✗ | batch_size = size(input,dim=6) | |
| 1520 | rank default | ||
| 1521 | ✗ | batch_size = size(input,dim=rank(input)) | |
| 1522 | end select | ||
| 1523 |
9/18✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 8 times.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 8 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 8 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8 times.
|
8 | allocate(output(this%num_outputs,batch_size)) |
| 1524 | |||
| 1525 | |||
| 1526 | !!!----------------------------------------------------------------------------- | ||
| 1527 | !!! reset batch size for testing | ||
| 1528 | !!!----------------------------------------------------------------------------- | ||
| 1529 | 8 | call this%set_batch_size(batch_size) | |
| 1530 | |||
| 1531 | |||
| 1532 | !!!----------------------------------------------------------------------------- | ||
| 1533 | !!! predict | ||
| 1534 | !!!----------------------------------------------------------------------------- | ||
| 1535 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
|
8 | if(present(addit_input).and.present(addit_layer))then |
| 1536 | call this%forward(get_sample(input,1,batch_size),& | ||
| 1537 | ✗ | addit_input(:,1:batch_size),addit_layer) | |
| 1538 | else | ||
| 1539 | 8 | call this%forward(get_sample(input,1,batch_size)) | |
| 1540 | end if | ||
| 1541 | |||
| 1542 | 16 | select type(current => this%model(this%num_layers)%layer) | |
| 1543 | type is(full_layer_type) | ||
| 1544 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✓ Branch 24 taken 8 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 8 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 8 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 8 times.
✓ Branch 33 taken 8 times.
✓ Branch 34 taken 8 times.
✓ Branch 35 taken 8 times.
|
24 | output = current%output(:,1:batch_size) |
| 1545 | end select | ||
| 1546 | |||
| 1547 | 8 | end function predict_1d | |
| 1548 | !!!############################################################################# | ||
| 1549 | |||
| 1550 |
126/257✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✓ Branch 7 taken 2 times.
✓ Branch 8 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 2 times.
✓ Branch 14 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✓ Branch 17 taken 12 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 11 times.
✓ Branch 22 taken 7 times.
✓ Branch 23 taken 4 times.
✓ Branch 25 taken 1 times.
✓ Branch 26 taken 3 times.
✓ Branch 28 taken 1 times.
✓ Branch 29 taken 2 times.
✓ Branch 31 taken 1 times.
✓ Branch 32 taken 1 times.
✓ Branch 34 taken 1 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 12 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 12 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 12 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 11 times.
✓ Branch 43 taken 1 times.
✓ Branch 44 taken 1 times.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✓ Branch 47 taken 7 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 7 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 7 times.
✗ Branch 55 not taken.
✗ Branch 56 not taken.
✓ Branch 58 taken 2 times.
✗ Branch 59 not taken.
✓ Branch 61 taken 2 times.
✗ Branch 62 not taken.
✗ Branch 64 not taken.
✓ Branch 65 taken 7 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 7 times.
✓ Branch 71 taken 4 times.
✓ Branch 72 taken 3 times.
✗ Branch 73 not taken.
✓ Branch 74 taken 2 times.
✗ Branch 76 not taken.
✓ Branch 77 taken 2 times.
✗ Branch 80 not taken.
✓ Branch 81 taken 2 times.
✗ Branch 82 not taken.
✓ Branch 83 taken 2 times.
✗ Branch 85 not taken.
✓ Branch 86 taken 2 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 2 times.
✓ Branch 90 taken 509 times.
✗ Branch 91 not taken.
✓ Branch 92 taken 509 times.
✗ Branch 93 not taken.
✓ Branch 94 taken 509 times.
✗ Branch 95 not taken.
✗ Branch 96 not taken.
✗ Branch 97 not taken.
✗ Branch 98 not taken.
✗ Branch 99 not taken.
✗ Branch 100 not taken.
✗ Branch 101 not taken.
✗ Branch 102 not taken.
✗ Branch 103 not taken.
✗ Branch 104 not taken.
✓ Branch 105 taken 1513 times.
✗ Branch 107 not taken.
✓ Branch 108 taken 1513 times.
✓ Branch 111 taken 1006 times.
✓ Branch 112 taken 507 times.
✗ Branch 113 not taken.
✓ Branch 114 taken 1513 times.
✗ Branch 116 not taken.
✓ Branch 117 taken 1513 times.
✓ Branch 120 taken 1006 times.
✓ Branch 121 taken 507 times.
✗ Branch 122 not taken.
✓ Branch 123 taken 1519 times.
✗ Branch 125 not taken.
✓ Branch 126 taken 1519 times.
✓ Branch 129 taken 1010 times.
✓ Branch 130 taken 509 times.
✗ Branch 131 not taken.
✓ Branch 132 taken 6 times.
✗ Branch 134 not taken.
✓ Branch 135 taken 6 times.
✓ Branch 138 taken 4 times.
✓ Branch 139 taken 2 times.
✓ Branch 140 taken 2 times.
✓ Branch 141 taken 2 times.
✓ Branch 142 taken 2 times.
✗ Branch 143 not taken.
✓ Branch 144 taken 2 times.
✓ Branch 145 taken 2 times.
✓ Branch 146 taken 2 times.
✗ Branch 147 not taken.
✗ Branch 148 not taken.
✓ Branch 149 taken 1513 times.
✗ Branch 151 not taken.
✓ Branch 152 taken 1513 times.
✓ Branch 155 taken 1006 times.
✓ Branch 156 taken 507 times.
✗ Branch 157 not taken.
✗ Branch 158 not taken.
✗ Branch 160 not taken.
✗ Branch 161 not taken.
✗ Branch 163 not taken.
✗ Branch 164 not taken.
✗ Branch 165 not taken.
✗ Branch 166 not taken.
✗ Branch 167 not taken.
✗ Branch 168 not taken.
✗ Branch 169 not taken.
✗ Branch 170 not taken.
✗ Branch 171 not taken.
✓ Branch 172 taken 521 times.
✗ Branch 174 not taken.
✓ Branch 175 taken 521 times.
✓ Branch 178 taken 521 times.
✗ Branch 179 not taken.
✗ Branch 180 not taken.
✓ Branch 181 taken 502 times.
✗ Branch 183 not taken.
✓ Branch 184 taken 502 times.
✗ Branch 186 not taken.
✓ Branch 187 taken 502 times.
✗ Branch 188 not taken.
✓ Branch 189 taken 502 times.
✗ Branch 190 not taken.
✓ Branch 191 taken 502 times.
✗ Branch 192 not taken.
✓ Branch 193 taken 502 times.
✓ Branch 194 taken 1 times.
✓ Branch 195 taken 501 times.
✓ Branch 196 taken 1 times.
✓ Branch 197 taken 500 times.
✗ Branch 198 not taken.
✓ Branch 199 taken 500 times.
✗ Branch 200 not taken.
✓ Branch 201 taken 500 times.
✗ Branch 202 not taken.
✓ Branch 203 taken 500 times.
✗ Branch 204 not taken.
✓ Branch 205 taken 500 times.
✗ Branch 206 not taken.
✓ Branch 207 taken 500 times.
✗ Branch 208 not taken.
✓ Branch 209 taken 500 times.
✗ Branch 210 not taken.
✓ Branch 211 taken 500 times.
✗ Branch 212 not taken.
✓ Branch 213 taken 500 times.
✗ Branch 214 not taken.
✓ Branch 215 taken 500 times.
✗ Branch 216 not taken.
✓ Branch 217 taken 500 times.
✗ Branch 218 not taken.
✓ Branch 219 taken 500 times.
✗ Branch 220 not taken.
✓ Branch 221 taken 500 times.
✗ Branch 222 not taken.
✓ Branch 223 taken 500 times.
✓ Branch 224 taken 500 times.
✗ Branch 225 not taken.
✓ Branch 226 taken 1 times.
✗ Branch 227 not taken.
✓ Branch 228 taken 1 times.
✗ Branch 229 not taken.
✗ Branch 230 not taken.
✓ Branch 231 taken 1 times.
✗ Branch 232 not taken.
✓ Branch 233 taken 3 times.
✗ Branch 235 not taken.
✓ Branch 236 taken 3 times.
✗ Branch 239 not taken.
✓ Branch 240 taken 3 times.
✗ Branch 241 not taken.
✓ Branch 242 taken 500 times.
✓ Branch 243 taken 500 times.
✗ Branch 244 not taken.
✗ Branch 245 not taken.
✓ Branch 246 taken 500 times.
✗ Branch 248 not taken.
✓ Branch 249 taken 500 times.
✓ Branch 251 taken 500 times.
✗ Branch 252 not taken.
✗ Branch 253 not taken.
✓ Branch 254 taken 500 times.
✓ Branch 255 taken 500 times.
✗ Branch 256 not taken.
✗ Branch 257 not taken.
✓ Branch 258 taken 1 times.
✓ Branch 259 taken 1 times.
✗ Branch 260 not taken.
✗ Branch 261 not taken.
✓ Branch 262 taken 3 times.
✗ Branch 264 not taken.
✓ Branch 265 taken 3 times.
✗ Branch 268 not taken.
✓ Branch 269 taken 3 times.
✗ Branch 270 not taken.
✓ Branch 271 taken 1 times.
✗ Branch 273 not taken.
✓ Branch 274 taken 1 times.
✓ Branch 276 taken 1 times.
✗ Branch 277 not taken.
✗ Branch 278 not taken.
✓ Branch 279 taken 1 times.
✓ Branch 280 taken 1 times.
✗ Branch 281 not taken.
✗ Branch 282 not taken.
✗ Branch 283 not taken.
✗ Branch 285 not taken.
✗ Branch 286 not taken.
✗ Branch 288 not taken.
✗ Branch 289 not taken.
✗ Branch 290 not taken.
✗ Branch 291 not taken.
✗ Branch 292 not taken.
✗ Branch 293 not taken.
✓ Branch 294 taken 8 times.
✗ Branch 295 not taken.
✓ Branch 296 taken 8 times.
✗ Branch 297 not taken.
✓ Branch 298 taken 8 times.
✗ Branch 299 not taken.
✗ Branch 300 not taken.
✗ Branch 301 not taken.
✗ Branch 302 not taken.
✗ Branch 303 not taken.
✗ Branch 304 not taken.
✗ Branch 305 not taken.
✗ Branch 306 not taken.
✗ Branch 307 not taken.
✗ Branch 308 not taken.
✓ Branch 309 taken 8 times.
✗ Branch 311 not taken.
✓ Branch 312 taken 8 times.
✓ Branch 314 taken 8 times.
✗ Branch 315 not taken.
|
9205 | end submodule network_submodule |
| 1551 | !!!############################################################################# | ||
| 1552 |