| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_conv | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | |||
| 4 | contains | ||
| 5 | |||
| 6 | !############################################################################### | ||
| 7 | − | module function conv1d(input, kernel, stride, dilation) result(output) | |
| 8 | !! 1D convolution operation | ||
| 9 | implicit none | ||
| 10 | |||
| 11 | ! Arguments | ||
| 12 | type(array_type), intent(in), target :: input | ||
| 13 | type(array_type), intent(in), target :: kernel | ||
| 14 | integer, intent(in) :: stride | ||
| 15 | integer, intent(in) :: dilation | ||
| 16 | type(array_type), pointer :: output | ||
| 17 | |||
| 18 | ! Local variables | ||
| 19 | integer :: i, k, c_in, c_out, s | ||
| 20 | integer :: i_in, k_idx | ||
| 21 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 22 | real(real32) :: conv_sum | ||
| 23 | integer, dimension(3) :: output_shape | ||
| 24 | |||
| 25 | ! Extract dimensions | ||
| 26 | ! input: [H_in, C_in, B] | ||
| 27 | ! kernel: [K, C_in, C_out] | ||
| 28 | − | input_h = input%shape(1) | |
| 29 | − | num_channels = input%shape(2) | |
| 30 | − | kernel_h = kernel%shape(1) | |
| 31 | − | num_filters = kernel%shape(3) | |
| 32 | |||
| 33 | ! Calculate output dimensions | ||
| 34 | output_h = (input_h - dilation*(kernel_h - 1) - 1) / & | ||
| 35 | − | stride + 1 | |
| 36 | − | output_shape = [output_h, num_filters, size(input%val, dim=2)] | |
| 37 | |||
| 38 | − | output => input%create_result(array_shape = output_shape) | |
| 39 | − | output%val = 0._real32 | |
| 40 | |||
| 41 | ! Perform convolution | ||
| 42 | do concurrent(s = 1:output_shape(3), c_out = 1:num_filters, & | ||
| 43 | − | i = 1:output_h) | |
| 44 | − | conv_sum = 0._real32 | |
| 45 | − | do c_in = 1, num_channels | |
| 46 | − | do k = 1, kernel_h | |
| 47 | − | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 48 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 49 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 50 | − | ( c_out - 1 ) * kernel_h * num_channels | |
| 51 | conv_sum = conv_sum + & | ||
| 52 | − | input%val(i_in + ( c_in - 1 ) * input_h, s) * & | |
| 53 | − | kernel%val(k_idx, 1) | |
| 54 | end if | ||
| 55 | end do | ||
| 56 | end do | ||
| 57 | − | output%val(i + (c_out-1)*output_h, s) = conv_sum | |
| 58 | end do | ||
| 59 | |||
| 60 | ! Store parameters for backward pass | ||
| 61 | − | allocate(output%indices(2)) | |
| 62 | − | output%indices(1) = num_channels | |
| 63 | − | output%indices(2) = num_filters | |
| 64 | − | allocate(output%adj_ja(1,3)) | |
| 65 | − | output%adj_ja(1,1) = stride | |
| 66 | − | output%adj_ja(1,2) = dilation | |
| 67 | − | output%adj_ja(1,3) = kernel_h | |
| 68 | |||
| 69 | − | output%get_partial_left => get_partial_conv1d_input | |
| 70 | − | output%get_partial_right => get_partial_conv1d_kernel | |
| 71 | − | output%get_partial_left_val => get_partial_conv1d_input_val | |
| 72 | − | output%get_partial_right_val => get_partial_conv1d_kernel_val | |
| 73 | − | if(input%requires_grad .or. kernel%requires_grad)then | |
| 74 | − | output%requires_grad = .true. | |
| 75 | − | output%is_forward = input%is_forward | |
| 76 | − | output%operation = 'conv1d' | |
| 77 | − | output%left_operand => input | |
| 78 | − | output%right_operand => kernel | |
| 79 | end if | ||
| 80 | |||
| 81 | − | end function conv1d | |
| 82 | !------------------------------------------------------------------------------- | ||
| 83 | − | function get_partial_conv1d_input(this, upstream_grad) result(output) | |
| 84 | !! Get partial derivative wrt input for 1D convolution | ||
| 85 | implicit none | ||
| 86 | |||
| 87 | class(array_type), intent(inout) :: this | ||
| 88 | type(array_type), intent(in) :: upstream_grad | ||
| 89 | type(array_type) :: output | ||
| 90 | |||
| 91 | − | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 92 | − | size(this%left_operand%val, dim=2) ]) | |
| 93 | − | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 94 | |||
| 95 | − | end function get_partial_conv1d_input | |
| 96 | !------------------------------------------------------------------------------- | ||
| 97 | − | function get_partial_conv1d_kernel(this, upstream_grad) result(output) | |
| 98 | !! Get partial derivative wrt kernel for 1D convolution | ||
| 99 | implicit none | ||
| 100 | |||
| 101 | class(array_type), intent(inout) :: this | ||
| 102 | type(array_type), intent(in) :: upstream_grad | ||
| 103 | type(array_type) :: output | ||
| 104 | |||
| 105 | − | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 106 | − | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 107 | |||
| 108 | − | end function get_partial_conv1d_kernel | |
| 109 | !------------------------------------------------------------------------------- | ||
| 110 | − | pure subroutine get_partial_conv1d_input_val(this, upstream_grad, output) | |
| 111 | !! Get partial derivative wrt input for 1D convolution (subroutine version) | ||
| 112 | implicit none | ||
| 113 | |||
| 114 | class(array_type), intent(in) :: this | ||
| 115 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 116 | real(real32), dimension(:,:), intent(out) :: output | ||
| 117 | |||
| 118 | ! Local variables | ||
| 119 | integer :: i, k, c_in, c_out, s | ||
| 120 | integer :: i_in, k_idx, out_idx | ||
| 121 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 122 | integer :: stride, dilation | ||
| 123 | real(real32) :: grad_val, kernel_val | ||
| 124 | |||
| 125 | |||
| 126 | ! Unpack parameters | ||
| 127 | − | num_channels = this%indices(1) | |
| 128 | − | num_filters = this%indices(2) | |
| 129 | − | stride = this%adj_ja(1,1) | |
| 130 | − | dilation = this%adj_ja(1,2) | |
| 131 | − | kernel_h = this%adj_ja(1,3) | |
| 132 | |||
| 133 | − | input_h = this%left_operand%shape(1) | |
| 134 | − | output_h = this%shape(1) | |
| 135 | |||
| 136 | − | output = 0._real32 | |
| 137 | |||
| 138 | ! Parallelised over batch, channels, and output positions | ||
| 139 | − | do concurrent(s = 1:size(upstream_grad, dim=2), c_in = 1:num_channels, & | |
| 140 | − | i = 1:output_h, c_out = 1:num_filters) | |
| 141 | − | out_idx = i + (c_out-1)*output_h | |
| 142 | − | grad_val = upstream_grad(out_idx, s) | |
| 143 | |||
| 144 | − | if(abs(grad_val) > 1.e-30_real32)then | |
| 145 | − | do k = 1, kernel_h | |
| 146 | − | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 147 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 148 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 149 | − | ( c_out - 1 ) * kernel_h * num_channels | |
| 150 | − | kernel_val = this%right_operand%val(k_idx, 1) | |
| 151 | − | output(i_in + ( c_in - 1 ) * input_h, s) = & | |
| 152 | − | output(i_in + ( c_in - 1 ) * input_h, s) + & | |
| 153 | − | grad_val * kernel_val | |
| 154 | end if | ||
| 155 | end do | ||
| 156 | end if | ||
| 157 | end do | ||
| 158 | |||
| 159 | − | end subroutine get_partial_conv1d_input_val | |
| 160 | !------------------------------------------------------------------------------- | ||
| 161 | − | pure subroutine get_partial_conv1d_kernel_val(this, upstream_grad, output) | |
| 162 | !! Get partial derivative wrt kernel for 1D convolution (subroutine version) | ||
| 163 | implicit none | ||
| 164 | |||
| 165 | class(array_type), intent(in) :: this | ||
| 166 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 167 | real(real32), dimension(:,:), intent(out) :: output | ||
| 168 | |||
| 169 | ! Local variables | ||
| 170 | integer :: i, k, c_in, c_out, s | ||
| 171 | integer :: i_in, k_idx, out_idx | ||
| 172 | integer :: input_h, kernel_h, output_h, num_channels, num_filters | ||
| 173 | integer :: stride, dilation | ||
| 174 | real(real32) :: grad_sum | ||
| 175 | |||
| 176 | |||
| 177 | ! Unpack parameters | ||
| 178 | − | num_channels = this%indices(1) | |
| 179 | − | num_filters = this%indices(2) | |
| 180 | − | stride = this%adj_ja(1,1) | |
| 181 | − | dilation = this%adj_ja(1,2) | |
| 182 | − | kernel_h = this%adj_ja(1,3) | |
| 183 | |||
| 184 | − | input_h = this%left_operand%shape(1) | |
| 185 | − | output_h = this%shape(1) | |
| 186 | |||
| 187 | − | output = 0._real32 | |
| 188 | |||
| 189 | ! Parallelised over filters, channels, and kernel positions | ||
| 190 | − | do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, k = 1:kernel_h) | |
| 191 | k_idx = k + ( c_in - 1 ) * kernel_h + & | ||
| 192 | − | ( c_out - 1 ) * kernel_h * num_channels | |
| 193 | |||
| 194 | − | grad_sum = 0._real32 | |
| 195 | − | do s = 1, size(upstream_grad, dim=2) | |
| 196 | − | do i = 1, output_h | |
| 197 | − | i_in = ( i - 1 ) * stride + ( k - 1 ) * dilation + 1 | |
| 198 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 199 | − | out_idx = i + ( c_out - 1 ) * output_h | |
| 200 | − | grad_sum = grad_sum + upstream_grad(out_idx, s) * & | |
| 201 | − | this%left_operand%val(i_in + ( c_in - 1 ) * input_h, s) | |
| 202 | end if | ||
| 203 | end do | ||
| 204 | end do | ||
| 205 | − | output(k_idx, 1) = grad_sum | |
| 206 | end do | ||
| 207 | |||
| 208 | − | end subroutine get_partial_conv1d_kernel_val | |
| 209 | !############################################################################### | ||
| 210 | |||
| 211 | |||
| 212 | !############################################################################### | ||
| 213 | − | module function conv2d(input, kernel, stride, dilation) result(output) | |
| 214 | !! 2D convolution operation | ||
| 215 | implicit none | ||
| 216 | |||
| 217 | ! Arguments | ||
| 218 | type(array_type), intent(in), target :: input | ||
| 219 | type(array_type), intent(in), target :: kernel | ||
| 220 | integer, dimension(2), intent(in) :: stride | ||
| 221 | integer, dimension(2), intent(in) :: dilation | ||
| 222 | type(array_type), pointer :: output | ||
| 223 | |||
| 224 | ! Local variables | ||
| 225 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 226 | integer :: i_in, j_in, k_idx, out_idx, in_idx, in_base_idx, k_base_idx | ||
| 227 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 228 | integer :: output_h, output_w, num_channels, num_filters | ||
| 229 | integer :: channel_size_in, channel_size_out, kernel_channel_size | ||
| 230 | integer :: dil_kernel_h_m1, dil_kernel_w_m1 | ||
| 231 | real(real32) :: conv_sum | ||
| 232 | integer, dimension(4) :: output_shape | ||
| 233 | |||
| 234 | ! Extract dimensions | ||
| 235 | ! input: [H_in, W_in, C_in, B] | ||
| 236 | ! kernel: [K_h, K_w, C_in, C_out] | ||
| 237 | − | input_h = input%shape(1) | |
| 238 | − | input_w = input%shape(2) | |
| 239 | − | num_channels = input%shape(3) | |
| 240 | − | kernel_h = kernel%shape(1) | |
| 241 | − | kernel_w = kernel%shape(2) | |
| 242 | − | num_filters = kernel%shape(4) | |
| 243 | |||
| 244 | ! Pre-compute common values | ||
| 245 | − | channel_size_in = input_h * input_w | |
| 246 | − | kernel_channel_size = kernel_h * kernel_w | |
| 247 | − | dil_kernel_h_m1 = dilation(1) * (kernel_h - 1) | |
| 248 | − | dil_kernel_w_m1 = dilation(2) * (kernel_w - 1) | |
| 249 | |||
| 250 | ! Calculate output dimensions | ||
| 251 | − | output_h = (input_h - dil_kernel_h_m1 - 1) / stride(1) + 1 | |
| 252 | − | output_w = (input_w - dil_kernel_w_m1 - 1) / stride(2) + 1 | |
| 253 | output_shape = [output_h, output_w, num_filters, & | ||
| 254 | − | size(input%val, dim=2)] | |
| 255 | |||
| 256 | − | output => input%create_result(array_shape = output_shape) | |
| 257 | − | output%val = 0._real32 | |
| 258 | |||
| 259 | − | channel_size_out = output_h * output_w | |
| 260 | |||
| 261 | ! Perform convolution | ||
| 262 | do concurrent(s = 1:output_shape(4), c_out = 1:num_filters, & | ||
| 263 | − | j = 1:output_w, i = 1:output_h) | |
| 264 | − | conv_sum = 0._real32 | |
| 265 | − | do c_in = 1, num_channels | |
| 266 | − | in_base_idx = (c_in - 1) * channel_size_in | |
| 267 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 268 | − | (c_out - 1) * kernel_channel_size * num_channels | |
| 269 | − | do kj = 1, kernel_w | |
| 270 | − | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 271 | − | if(j_in .ge. 1 .and. j_in .le. input_w)then | |
| 272 | − | do ki = 1, kernel_h | |
| 273 | − | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 274 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 275 | − | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 276 | − | k_idx = ki + (kj - 1) * kernel_h + k_base_idx | |
| 277 | − | conv_sum = conv_sum + input%val(in_idx, s) * & | |
| 278 | − | kernel%val(k_idx, 1) | |
| 279 | end if | ||
| 280 | end do | ||
| 281 | end if | ||
| 282 | end do | ||
| 283 | end do | ||
| 284 | − | out_idx = i + (j - 1) * output_h + (c_out - 1) * channel_size_out | |
| 285 | − | output%val(out_idx, s) = conv_sum | |
| 286 | end do | ||
| 287 | |||
| 288 | ! Store parameters for backward pass | ||
| 289 | − | allocate(output%indices(2)) | |
| 290 | − | output%indices(1) = num_channels | |
| 291 | − | output%indices(2) = num_filters | |
| 292 | − | allocate(output%adj_ja(2,3)) | |
| 293 | − | output%adj_ja(1:2,1) = stride | |
| 294 | − | output%adj_ja(1:2,2) = dilation | |
| 295 | − | output%adj_ja(1,3) = kernel_h | |
| 296 | − | output%adj_ja(2,3) = kernel_w | |
| 297 | |||
| 298 | |||
| 299 | − | output%get_partial_left => get_partial_conv2d_input | |
| 300 | − | output%get_partial_right => get_partial_conv2d_kernel | |
| 301 | − | output%get_partial_left_val => get_partial_conv2d_input_val | |
| 302 | − | output%get_partial_right_val => get_partial_conv2d_kernel_val | |
| 303 | − | if(input%requires_grad .or. kernel%requires_grad)then | |
| 304 | − | output%requires_grad = .true. | |
| 305 | − | output%is_forward = input%is_forward | |
| 306 | − | output%operation = 'conv2d' | |
| 307 | − | output%left_operand => input | |
| 308 | − | output%right_operand => kernel | |
| 309 | end if | ||
| 310 | |||
| 311 | − | end function conv2d | |
| 312 | !------------------------------------------------------------------------------- | ||
| 313 | − | function get_partial_conv2d_input(this, upstream_grad) result(output) | |
| 314 | !! Get partial derivative wrt input for 2D convolution | ||
| 315 | implicit none | ||
| 316 | |||
| 317 | class(array_type), intent(inout) :: this | ||
| 318 | type(array_type), intent(in) :: upstream_grad | ||
| 319 | type(array_type) :: output | ||
| 320 | |||
| 321 | − | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 322 | − | size(this%left_operand%val, dim=2) ]) | |
| 323 | − | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 324 | |||
| 325 | − | end function get_partial_conv2d_input | |
| 326 | !------------------------------------------------------------------------------- | ||
| 327 | − | function get_partial_conv2d_kernel(this, upstream_grad) result(output) | |
| 328 | !! Get partial derivative wrt kernel for 2D convolution | ||
| 329 | implicit none | ||
| 330 | |||
| 331 | class(array_type), intent(inout) :: this | ||
| 332 | type(array_type), intent(in) :: upstream_grad | ||
| 333 | type(array_type) :: output | ||
| 334 | |||
| 335 | − | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 336 | − | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 337 | |||
| 338 | − | end function get_partial_conv2d_kernel | |
| 339 | !------------------------------------------------------------------------------- | ||
| 340 | − | pure subroutine get_partial_conv2d_input_val(this, upstream_grad, output) | |
| 341 | !! Get partial derivative wrt input for 2D convolution (subroutine version) | ||
| 342 | implicit none | ||
| 343 | |||
| 344 | class(array_type), intent(in) :: this | ||
| 345 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 346 | real(real32), dimension(:,:), intent(out) :: output | ||
| 347 | |||
| 348 | ! Local variables | ||
| 349 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 350 | integer :: i_in, j_in, k_idx, out_idx, in_idx | ||
| 351 | integer :: in_base_idx, k_base_idx, kernel_channel_size | ||
| 352 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 353 | integer :: output_h, output_w, num_channels, num_filters | ||
| 354 | integer, dimension(2) :: stride, dilation | ||
| 355 | integer :: channel_size_in, channel_size_out | ||
| 356 | real(real32) :: grad_val, kernel_val | ||
| 357 | |||
| 358 | |||
| 359 | ! Unpack parameters | ||
| 360 | − | num_channels = this%indices(1) | |
| 361 | − | num_filters = this%indices(2) | |
| 362 | − | stride = this%adj_ja(1:2,1) | |
| 363 | − | dilation = this%adj_ja(1:2,2) | |
| 364 | − | kernel_h = this%adj_ja(1,3) | |
| 365 | − | kernel_w = this%adj_ja(2,3) | |
| 366 | |||
| 367 | − | input_h = this%left_operand%shape(1) | |
| 368 | − | input_w = this%left_operand%shape(2) | |
| 369 | − | output_h = this%shape(1) | |
| 370 | − | output_w = this%shape(2) | |
| 371 | − | channel_size_in = input_h * input_w | |
| 372 | − | channel_size_out = output_h * output_w | |
| 373 | − | kernel_channel_size = kernel_h * kernel_w | |
| 374 | |||
| 375 | − | output = 0._real32 | |
| 376 | |||
| 377 | ! Parallelised over batch, output channels, output spatial dims | ||
| 378 | − | do concurrent(s = 1:size(upstream_grad, dim=2), c_out = 1:num_filters, & | |
| 379 | − | j = 1:output_w, i = 1:output_h) | |
| 380 | − | out_idx = i + (j-1)*output_h + (c_out-1)*channel_size_out | |
| 381 | − | grad_val = upstream_grad(out_idx, s) | |
| 382 | |||
| 383 | − | if(abs(grad_val) .gt. 1.e-30_real32)then | |
| 384 | − | do c_in = 1, num_channels | |
| 385 | − | in_base_idx = (c_in - 1) * channel_size_in | |
| 386 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 387 | − | (c_out - 1) * kernel_channel_size * num_channels | |
| 388 | |||
| 389 | − | do kj = 1, kernel_w | |
| 390 | − | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 391 | − | if(j_in .ge. 1 .and. j_in .le. input_w)then | |
| 392 | − | do ki = 1, kernel_h | |
| 393 | − | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 394 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 395 | − | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 396 | k_idx = (kernel_h - ki + 1) + & | ||
| 397 | − | (kernel_w - kj) * kernel_h + k_base_idx | |
| 398 | − | kernel_val = this%right_operand%val(k_idx, 1) | |
| 399 | − | output(in_idx, s) = output(in_idx, s) + & | |
| 400 | − | grad_val * kernel_val | |
| 401 | end if | ||
| 402 | end do | ||
| 403 | end if | ||
| 404 | end do | ||
| 405 | end do | ||
| 406 | end if | ||
| 407 | end do | ||
| 408 | |||
| 409 | − | end subroutine get_partial_conv2d_input_val | |
| 410 | !------------------------------------------------------------------------------- | ||
| 411 | − | pure subroutine get_partial_conv2d_kernel_val(this, upstream_grad, output) | |
| 412 | !! Get partial derivative wrt kernel for 2D convolution (subroutine version) | ||
| 413 | implicit none | ||
| 414 | |||
| 415 | class(array_type), intent(in) :: this | ||
| 416 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 417 | real(real32), dimension(:,:), intent(out) :: output | ||
| 418 | |||
| 419 | ! Local variables | ||
| 420 | integer :: i, j, ki, kj, c_in, c_out, s | ||
| 421 | integer :: i_in, j_in, k_idx, out_idx, in_idx | ||
| 422 | integer :: in_base_idx, out_base_idx, k_base_idx, kernel_channel_size | ||
| 423 | integer :: input_h, input_w, kernel_h, kernel_w | ||
| 424 | integer :: output_h, output_w, num_channels, num_filters | ||
| 425 | integer, dimension(2) :: stride, dilation | ||
| 426 | integer :: channel_size_in, channel_size_out | ||
| 427 | real(real32) :: grad_sum | ||
| 428 | |||
| 429 | |||
| 430 | ! Unpack parameters | ||
| 431 | − | num_channels = this%indices(1) | |
| 432 | − | num_filters = this%indices(2) | |
| 433 | − | stride = this%adj_ja(1:2,1) | |
| 434 | − | dilation = this%adj_ja(1:2,2) | |
| 435 | − | kernel_h = this%adj_ja(1,3) | |
| 436 | − | kernel_w = this%adj_ja(2,3) | |
| 437 | |||
| 438 | − | input_h = this%left_operand%shape(1) | |
| 439 | − | input_w = this%left_operand%shape(2) | |
| 440 | − | output_h = this%shape(1) | |
| 441 | − | output_w = this%shape(2) | |
| 442 | − | channel_size_in = input_h * input_w | |
| 443 | − | channel_size_out = output_h * output_w | |
| 444 | − | kernel_channel_size = kernel_h * kernel_w | |
| 445 | |||
| 446 | − | output = 0._real32 | |
| 447 | |||
| 448 | ! Parallelised over filters, channels, and kernel dimensions | ||
| 449 | do concurrent(c_out = 1:num_filters, c_in = 1:num_channels, & | ||
| 450 | − | kj = 1:kernel_w, ki = 1:kernel_h) | |
| 451 | − | out_base_idx = (c_out - 1) * channel_size_out | |
| 452 | − | in_base_idx = (c_in - 1) * channel_size_in | |
| 453 | k_base_idx = (c_in - 1) * kernel_channel_size + & | ||
| 454 | − | (c_out - 1) * kernel_channel_size * num_channels | |
| 455 | − | k_idx = ki + (kj - 1) * kernel_h + k_base_idx | |
| 456 | |||
| 457 | − | grad_sum = 0._real32 | |
| 458 | − | do s = 1, size(upstream_grad, dim=2) | |
| 459 | − | do j = 1, output_w | |
| 460 | − | j_in = (j - 1) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 461 | − | if(j_in .ge. 1 .and. j_in .le. input_w)then | |
| 462 | − | do i = 1, output_h | |
| 463 | − | i_in = (i - 1) * stride(1) + (ki - 1) * dilation(1) + 1 | |
| 464 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 465 | − | in_idx = i_in + (j_in - 1) * input_h + in_base_idx | |
| 466 | − | out_idx = i + (j - 1) * output_h + out_base_idx | |
| 467 | grad_sum = grad_sum + & | ||
| 468 | − | upstream_grad(out_idx, s) * this%left_operand%val(in_idx, s) | |
| 469 | end if | ||
| 470 | end do | ||
| 471 | end if | ||
| 472 | end do | ||
| 473 | end do | ||
| 474 | − | output(k_idx, 1) = grad_sum | |
| 475 | end do | ||
| 476 | |||
| 477 | − | end subroutine get_partial_conv2d_kernel_val | |
| 478 | !############################################################################### | ||
| 479 | |||
| 480 | |||
| 481 | !############################################################################### | ||
| 482 | − | module function conv3d(input, kernel, stride, dilation) result(output) | |
| 483 | !! 3D convolution operation | ||
| 484 | implicit none | ||
| 485 | |||
| 486 | ! Arguments | ||
| 487 | type(array_type), intent(in), target :: input | ||
| 488 | type(array_type), intent(in), target :: kernel | ||
| 489 | integer, dimension(3), intent(in) :: stride | ||
| 490 | integer, dimension(3), intent(in) :: dilation | ||
| 491 | type(array_type), pointer :: output | ||
| 492 | |||
| 493 | ! Local variables | ||
| 494 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 495 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 496 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 497 | integer :: output_h, output_w, output_d | ||
| 498 | integer :: num_channels, num_filters | ||
| 499 | integer :: channel_size_in, channel_size_out | ||
| 500 | real(real32) :: conv_sum | ||
| 501 | integer, dimension(5) :: output_shape | ||
| 502 | |||
| 503 | ! Extract dimensions | ||
| 504 | ! input: [H_in, W_in, D_in, C_in, B] | ||
| 505 | ! kernel: [K_h, K_w, K_d, C_in, C_out] | ||
| 506 | − | input_h = input%shape(1) | |
| 507 | − | input_w = input%shape(2) | |
| 508 | − | input_d = input%shape(3) | |
| 509 | − | num_channels = input%shape(4) | |
| 510 | − | kernel_h = kernel%shape(1) | |
| 511 | − | kernel_w = kernel%shape(2) | |
| 512 | − | kernel_d = kernel%shape(3) | |
| 513 | − | num_filters = kernel%shape(5) | |
| 514 | |||
| 515 | ! Calculate output dimensions | ||
| 516 | output_h = (input_h - dilation(1)*(kernel_h - 1) - 1) / & | ||
| 517 | − | stride(1) + 1 | |
| 518 | output_w = (input_w - dilation(2)*(kernel_w - 1) - 1) / & | ||
| 519 | − | stride(2) + 1 | |
| 520 | output_d = (input_d - dilation(3)*(kernel_d - 1) - 1) / & | ||
| 521 | − | stride(3) + 1 | |
| 522 | output_shape = [output_h, output_w, output_d, num_filters, & | ||
| 523 | − | size(input%val, dim=2)] | |
| 524 | |||
| 525 | − | output => input%create_result(array_shape = output_shape) | |
| 526 | − | output%val = 0._real32 | |
| 527 | |||
| 528 | − | channel_size_in = input_h * input_w * input_d | |
| 529 | − | channel_size_out = output_h * output_w * output_d | |
| 530 | |||
| 531 | ! Perform convolution - optimised with do concurrent | ||
| 532 | do concurrent(s = 1:output_shape(5), c_out = 1:num_filters, & | ||
| 533 | − | k = 1:output_d, j = 1:output_w, i = 1:output_h) | |
| 534 | |||
| 535 | − | conv_sum = 0._real32 | |
| 536 | − | do c_in = 1, num_channels | |
| 537 | − | do kk = 1, kernel_d | |
| 538 | − | k_in = ( k - 1 ) * stride(3) + ( kk - 1 ) * dilation(3) + 1 | |
| 539 | − | if(k_in .ge. 1 .and. k_in .le. input_d)then | |
| 540 | − | do kj = 1, kernel_w | |
| 541 | − | j_in = ( j - 1 ) * stride(2) + (kj - 1) * dilation(2) + 1 | |
| 542 | − | if(j_in .ge. 1 .and. j_in .le. input_w)then | |
| 543 | − | do ki = 1, kernel_h | |
| 544 | i_in = ( i - 1 ) * stride(1) + & | ||
| 545 | − | ( ki - 1 ) * dilation(1) + 1 | |
| 546 | − | if(i_in .ge. 1 .and. i_in .le. input_h)then | |
| 547 | in_idx = i_in + ( j_in - 1 ) * input_h + & | ||
| 548 | ( k_in - 1 ) * input_h * input_w + & | ||
| 549 | − | ( c_in - 1 ) * channel_size_in | |
| 550 | k_idx = ki + ( kj - 1 ) * kernel_h + & | ||
| 551 | ( kk - 1 ) * kernel_h * kernel_w + & | ||
| 552 | ( c_in - 1 ) * kernel_h * kernel_w * & | ||
| 553 | kernel_d + & | ||
| 554 | ( c_out - 1 ) * kernel_h * kernel_w * & | ||
| 555 | − | kernel_d * num_channels | |
| 556 | − | conv_sum = conv_sum + input%val(in_idx, s) * & | |
| 557 | − | kernel%val(k_idx, 1) | |
| 558 | end if | ||
| 559 | end do | ||
| 560 | end if | ||
| 561 | end do | ||
| 562 | end if | ||
| 563 | end do | ||
| 564 | end do | ||
| 565 | out_idx = i + ( j - 1 ) * output_h + & | ||
| 566 | ( k - 1 ) * output_h * output_w + & | ||
| 567 | − | ( c_out - 1 ) * channel_size_out | |
| 568 | − | output%val(out_idx, s) = conv_sum | |
| 569 | end do | ||
| 570 | |||
| 571 | ! Store parameters for backward pass | ||
| 572 | − | allocate(output%indices(2)) | |
| 573 | − | output%indices(1) = num_channels | |
| 574 | − | output%indices(2) = num_filters | |
| 575 | − | allocate(output%adj_ja(3,3)) | |
| 576 | − | output%adj_ja(1:3,1) = stride | |
| 577 | − | output%adj_ja(1:3,2) = dilation | |
| 578 | − | output%adj_ja(1,3) = kernel_h | |
| 579 | − | output%adj_ja(2,3) = kernel_w | |
| 580 | − | output%adj_ja(3,3) = kernel_d | |
| 581 | |||
| 582 | − | output%get_partial_left => get_partial_conv3d_input | |
| 583 | − | output%get_partial_right => get_partial_conv3d_kernel | |
| 584 | − | output%get_partial_left_val => get_partial_conv3d_input_val | |
| 585 | − | output%get_partial_right_val => get_partial_conv3d_kernel_val | |
| 586 | − | if(input%requires_grad .or. kernel%requires_grad)then | |
| 587 | − | output%requires_grad = .true. | |
| 588 | − | output%is_forward = input%is_forward | |
| 589 | − | output%operation = 'conv3d' | |
| 590 | − | output%left_operand => input | |
| 591 | − | output%right_operand => kernel | |
| 592 | end if | ||
| 593 | |||
| 594 | − | end function conv3d | |
| 595 | !------------------------------------------------------------------------------- | ||
| 596 | − | function get_partial_conv3d_input(this, upstream_grad) result(output) | |
| 597 | !! Get partial derivative wrt input for 3D convolution | ||
| 598 | implicit none | ||
| 599 | |||
| 600 | class(array_type), intent(inout) :: this | ||
| 601 | type(array_type), intent(in) :: upstream_grad | ||
| 602 | type(array_type) :: output | ||
| 603 | |||
| 604 | − | call output%allocate(array_shape = [ this%left_operand%shape, & | |
| 605 | − | size(this%left_operand%val, dim=2) ]) | |
| 606 | − | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 607 | |||
| 608 | − | end function get_partial_conv3d_input | |
| 609 | !------------------------------------------------------------------------------- | ||
| 610 | − | function get_partial_conv3d_kernel(this, upstream_grad) result(output) | |
| 611 | !! Get partial derivative wrt kernel for 3D convolution | ||
| 612 | implicit none | ||
| 613 | |||
| 614 | class(array_type), intent(inout) :: this | ||
| 615 | type(array_type), intent(in) :: upstream_grad | ||
| 616 | type(array_type) :: output | ||
| 617 | |||
| 618 | − | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 619 | − | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 620 | |||
| 621 | − | end function get_partial_conv3d_kernel | |
| 622 | !------------------------------------------------------------------------------- | ||
| 623 | − | pure subroutine get_partial_conv3d_input_val(this, upstream_grad, output) | |
| 624 | !! Get partial derivative wrt input for 3D convolution (subroutine version) | ||
| 625 | implicit none | ||
| 626 | |||
| 627 | class(array_type), intent(in) :: this | ||
| 628 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 629 | real(real32), dimension(:,:), intent(out) :: output | ||
| 630 | |||
| 631 | ! Local variables | ||
| 632 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 633 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 634 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 635 | integer :: output_h, output_w, output_d | ||
| 636 | integer :: num_channels, num_filters | ||
| 637 | integer :: channel_size_in, channel_size_out | ||
| 638 | integer, dimension(3) :: stride, dilation | ||
| 639 | real(real32) :: grad_val, kernel_val | ||
| 640 | |||
| 641 | |||
| 642 | ! Unpack parameters | ||
| 643 | − | num_channels = this%indices(1) | |
| 644 | − | num_filters = this%indices(2) | |
| 645 | − | stride = this%adj_ja(1:3,1) | |
| 646 | − | dilation = this%adj_ja(1:3,2) | |
| 647 | − | kernel_h = this%adj_ja(1,3) | |
| 648 | − | kernel_w = this%adj_ja(2,3) | |
| 649 | − | kernel_d = this%adj_ja(3,3) | |
| 650 | |||
| 651 | − | input_h = this%left_operand%shape(1) | |
| 652 | − | input_w = this%left_operand%shape(2) | |
| 653 | − | input_d = this%left_operand%shape(3) | |
| 654 | − | output_h = this%shape(1) | |
| 655 | − | output_w = this%shape(2) | |
| 656 | − | output_d = this%shape(3) | |
| 657 | |||
| 658 | − | output = 0._real32 | |
| 659 | |||
| 660 | − | channel_size_in = input_h * input_w * input_d | |
| 661 | − | channel_size_out = output_h * output_w * output_d | |
| 662 | |||
| 663 | − | do s = 1, size(upstream_grad, dim=2) | |
| 664 | − | do c_in = 1, num_channels | |
| 665 | − | do k = 1, output_d | |
| 666 | − | do j = 1, output_w | |
| 667 | − | do i = 1, output_h | |
| 668 | − | do c_out = 1, num_filters | |
| 669 | out_idx = i + ( j - 1 ) * output_h + & | ||
| 670 | ( k - 1 ) * output_h * output_w + & | ||
| 671 | − | ( c_out - 1 ) * channel_size_out | |
| 672 | − | grad_val = upstream_grad(out_idx, s) | |
| 673 | |||
| 674 | − | do kk = 1, kernel_d | |
| 675 | k_in = ( k - 1 ) * stride(3) + & | ||
| 676 | − | ( kk - 1 ) * dilation(3) + 1 | |
| 677 | − | if( k_in .ge. 1 .and. k_in .le. input_d )then | |
| 678 | − | do kj = 1, kernel_w | |
| 679 | j_in = ( j - 1 ) * stride(2) + & | ||
| 680 | − | ( kj - 1 ) * dilation(2) + 1 | |
| 681 | − | if( j_in .ge. 1 .and. j_in .le. input_w )then | |
| 682 | − | do ki = 1, kernel_h | |
| 683 | i_in = ( i - 1 ) * stride(1) + & | ||
| 684 | − | ( ki - 1 ) * dilation(1) + 1 | |
| 685 | − | if( i_in .ge. 1 .and. & | |
| 686 | − | i_in .le. input_h )then | |
| 687 | in_idx = i_in + & | ||
| 688 | ( j_in - 1 ) * input_h + & | ||
| 689 | ( k_in - 1 ) * input_h * input_w + & | ||
| 690 | − | ( c_in - 1 ) * channel_size_in | |
| 691 | k_idx = ki + ( kj - 1 ) * kernel_h + & | ||
| 692 | ( kk - 1 ) * kernel_h * kernel_w + & | ||
| 693 | ( c_in - 1 ) * kernel_h * & | ||
| 694 | kernel_w * kernel_d + & | ||
| 695 | ( c_out - 1 ) * kernel_h * & | ||
| 696 | − | kernel_w * kernel_d * num_channels | |
| 697 | − | kernel_val = this%right_operand%val(k_idx, 1) | |
| 698 | − | output(in_idx, s) = & | |
| 699 | − | output(in_idx, s) + & | |
| 700 | − | grad_val * kernel_val | |
| 701 | end if | ||
| 702 | end do | ||
| 703 | end if | ||
| 704 | end do | ||
| 705 | end if | ||
| 706 | end do | ||
| 707 | end do | ||
| 708 | end do | ||
| 709 | end do | ||
| 710 | end do | ||
| 711 | end do | ||
| 712 | end do | ||
| 713 | |||
| 714 | − | end subroutine get_partial_conv3d_input_val | |
| 715 | !------------------------------------------------------------------------------- | ||
| 716 | − | pure subroutine get_partial_conv3d_kernel_val(this, upstream_grad, output) | |
| 717 | !! Get partial derivative wrt kernel for 3D convolution (subroutine version) | ||
| 718 | implicit none | ||
| 719 | |||
| 720 | class(array_type), intent(in) :: this | ||
| 721 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 722 | real(real32), dimension(:,:), intent(out) :: output | ||
| 723 | |||
| 724 | ! Local variables | ||
| 725 | integer :: i, j, k, ki, kj, kk, c_in, c_out, s | ||
| 726 | integer :: i_in, j_in, k_in, k_idx, out_idx, in_idx | ||
| 727 | integer :: input_h, input_w, input_d, kernel_h, kernel_w, kernel_d | ||
| 728 | integer :: output_h, output_w, output_d | ||
| 729 | integer :: num_channels, num_filters | ||
| 730 | integer :: channel_size_in, channel_size_out | ||
| 731 | integer, dimension(3) :: stride, dilation | ||
| 732 | real(real32) :: grad_sum | ||
| 733 | |||
| 734 | |||
| 735 | ! Unpack parameters | ||
| 736 | − | num_channels = this%indices(1) | |
| 737 | − | num_filters = this%indices(2) | |
| 738 | − | stride = this%adj_ja(1:3,1) | |
| 739 | − | dilation = this%adj_ja(1:3,2) | |
| 740 | − | kernel_h = this%adj_ja(1,3) | |
| 741 | − | kernel_w = this%adj_ja(2,3) | |
| 742 | − | kernel_d = this%adj_ja(3,3) | |
| 743 | |||
| 744 | − | input_h = this%left_operand%shape(1) | |
| 745 | − | input_w = this%left_operand%shape(2) | |
| 746 | − | input_d = this%left_operand%shape(3) | |
| 747 | − | output_h = this%shape(1) | |
| 748 | − | output_w = this%shape(2) | |
| 749 | − | output_d = this%shape(3) | |
| 750 | |||
| 751 | − | output = 0._real32 | |
| 752 | |||
| 753 | − | channel_size_in = input_h * input_w * input_d | |
| 754 | − | channel_size_out = output_h * output_w * output_d | |
| 755 | |||
| 756 | − | do c_out = 1, num_filters | |
| 757 | − | do c_in = 1, num_channels | |
| 758 | − | do kk = 1, kernel_d | |
| 759 | − | do kj = 1, kernel_w | |
| 760 | − | do ki = 1, kernel_h | |
| 761 | k_idx = ki + (kj-1)*kernel_h + & | ||
| 762 | (kk-1)*kernel_h*kernel_w + & | ||
| 763 | (c_in-1)*kernel_h*kernel_w*kernel_d + & | ||
| 764 | − | (c_out-1)*kernel_h*kernel_w*kernel_d*num_channels | |
| 765 | |||
| 766 | − | grad_sum = 0._real32 | |
| 767 | − | do s = 1, size(upstream_grad, dim=2) | |
| 768 | − | do k = 1, output_d | |
| 769 | − | k_in = (k-1)*stride(3) + (kk-1)*dilation(3) + 1 | |
| 770 | − | if(k_in >= 1 .and. k_in <= input_d)then | |
| 771 | − | do j = 1, output_w | |
| 772 | j_in = (j-1)*stride(2) + & | ||
| 773 | − | (kj-1)*dilation(2) + 1 | |
| 774 | − | if(j_in >= 1 .and. j_in <= input_w)then | |
| 775 | − | do i = 1, output_h | |
| 776 | i_in = (i-1)*stride(1) + & | ||
| 777 | − | (ki-1)*dilation(1) + 1 | |
| 778 | − | if(i_in >= 1 .and. i_in <= input_h)then | |
| 779 | in_idx = i_in + (j_in-1)*input_h + & | ||
| 780 | (k_in-1)*input_h*input_w + & | ||
| 781 | − | (c_in-1)*channel_size_in | |
| 782 | out_idx = i + (j-1)*output_h + & | ||
| 783 | (k-1)*output_h*output_w + & | ||
| 784 | − | (c_out-1)*channel_size_out | |
| 785 | grad_sum = grad_sum + & | ||
| 786 | − | upstream_grad(out_idx, s) * & | |
| 787 | − | this%left_operand%val(in_idx, s) | |
| 788 | end if | ||
| 789 | end do | ||
| 790 | end if | ||
| 791 | end do | ||
| 792 | end if | ||
| 793 | end do | ||
| 794 | end do | ||
| 795 | − | output(k_idx, 1) = grad_sum | |
| 796 | end do | ||
| 797 | end do | ||
| 798 | end do | ||
| 799 | end do | ||
| 800 | end do | ||
| 801 | |||
| 802 | − | end subroutine get_partial_conv3d_kernel_val | |
| 803 | !############################################################################### | ||
| 804 | |||
| 805 | end submodule athena__diffstruc_extd_submodule_conv | ||
| 806 |