| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule | ||
| 2 | !! Submodule containing implementations for extended diffstruc array operations | ||
| 3 | use coreutils, only: stop_program | ||
| 4 | use diffstruc, only: & | ||
| 5 | operator(+), operator(-), operator(*), concat, exp, sum, merge | ||
| 6 | |||
| 7 | contains | ||
| 8 | |||
| 9 | !############################################################################### | ||
| 10 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | module function add_array_ptr(a, idx1, idx2) result(c) |
| 11 | !! Add two autodiff arrays | ||
| 12 | implicit none | ||
| 13 | |||
| 14 | ! Arguments | ||
| 15 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 16 | integer, intent(in) :: idx1, idx2 | ||
| 17 | type(array_type), pointer :: c | ||
| 18 | |||
| 19 | ! Local variables | ||
| 20 | integer :: i | ||
| 21 | |||
| 22 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
|
1 | c => a(1)%array(idx1, idx2) + a(2)%array(idx1, idx2) |
| 23 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
|
1 | do i = 3, size(a), 1 |
| 24 |
0/12✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
|
1 | c => c + a(i)%array(idx1, idx2) |
| 25 | end do | ||
| 26 | 1 | end function add_array_ptr | |
| 27 | !############################################################################### | ||
| 28 | |||
| 29 | |||
| 30 | !############################################################################### | ||
| 31 |
1/2✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
|
1 | module function concat_array_ptr(a, idx1, idx2, dim) result(c) |
| 32 | !! Concatenate two autodiff arrays along a specified dimension | ||
| 33 | implicit none | ||
| 34 | |||
| 35 | ! Arguments | ||
| 36 | type(array_ptr_type), dimension(:), intent(in) :: a | ||
| 37 | integer, intent(in) :: idx1, idx2, dim | ||
| 38 | type(array_type), pointer :: c | ||
| 39 | |||
| 40 | ! Local variables | ||
| 41 | integer :: i | ||
| 42 | |||
| 43 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
|
1 | c => concat(a(1)%array(idx1, idx2), a(2)%array(idx1, idx2), dim) |
| 44 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
|
1 | do i = 3, size(a), 1 |
| 45 |
0/12✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
|
1 | c => concat(c, a(i)%array(idx1, idx2), dim) |
| 46 | end do | ||
| 47 | 1 | end function concat_array_ptr | |
| 48 | !############################################################################### | ||
| 49 | |||
| 50 | |||
| 51 | !##############################################################################! | ||
| 52 | ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! | ||
| 53 | !##############################################################################! | ||
| 54 | |||
| 55 | |||
| 56 | !############################################################################### | ||
| 57 | 43 | module function add_bias(input, bias, dim, dim_act_on_shape) result(output) | |
| 58 | !! Add bias to input array along specified dimension | ||
| 59 | implicit none | ||
| 60 | |||
| 61 | ! Arguments | ||
| 62 | class(array_type), intent(in), target :: input | ||
| 63 | class(array_type), intent(in), target :: bias | ||
| 64 | integer, intent(in) :: dim | ||
| 65 | logical, intent(in), optional :: dim_act_on_shape | ||
| 66 | type(array_type), pointer :: output | ||
| 67 | |||
| 68 | ! Local variables | ||
| 69 | integer :: i, j, k, s, idx, itmp1 | ||
| 70 | integer :: num_elements_pre, num_elements_post, num_dims | ||
| 71 | logical :: dim_act_on_shape_ | ||
| 72 | |||
| 73 |
1/2✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
|
43 | if(present(dim_act_on_shape))then |
| 74 | 43 | dim_act_on_shape_ = dim_act_on_shape | |
| 75 | else | ||
| 76 | ✗ | dim_act_on_shape_ = .false. | |
| 77 | end if | ||
| 78 | |||
| 79 | 43 | output => input%create_result() | |
| 80 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 43 times.
|
43 | allocate(output%indices(2)) |
| 81 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
|
43 | output%indices(1) = dim |
| 82 |
1/2✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
|
43 | if(dim_act_on_shape_)then |
| 83 | 43 | num_dims = size(input%shape) | |
| 84 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
|
43 | if(dim .gt. num_dims)then |
| 85 | ✗ | call stop_program("Dimension for add_bias exceeds input dimensions") | |
| 86 | ✗ | return | |
| 87 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
|
43 | elseif(size(bias%shape) .ne. 1)then |
| 88 | ✗ | call stop_program("Bias must be a 1D array") | |
| 89 | ✗ | return | |
| 90 | end if | ||
| 91 | 43 | num_elements_pre = 1 | |
| 92 | 43 | num_elements_post = 1 | |
| 93 |
2/2✓ Branch 0 taken 128 times.
✓ Branch 1 taken 43 times.
|
171 | do i = 1, num_dims |
| 94 |
2/2✓ Branch 0 taken 85 times.
✓ Branch 1 taken 43 times.
|
171 | if(i .lt. dim)then |
| 95 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 85 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 85 times.
|
85 | num_elements_pre = num_elements_pre * input%shape(i) |
| 96 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
|
43 | elseif(i .gt. dim)then |
| 97 | ✗ | num_elements_post = num_elements_post * input%shape(i) | |
| 98 | end if | ||
| 99 | end do | ||
| 100 | |||
| 101 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
|
43 | itmp1 = num_elements_pre * input%shape(dim) |
| 102 |
2/2✓ Branch 0 taken 68 times.
✓ Branch 1 taken 43 times.
|
111 | do s = 1, size(input%val, 2) |
| 103 |
2/2✓ Branch 0 taken 68 times.
✓ Branch 1 taken 68 times.
|
179 | do k = 1, num_elements_post |
| 104 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✓ Branch 6 taken 196 times.
✓ Branch 7 taken 68 times.
|
332 | do j = 1, bias%shape(1) |
| 105 | 196 | idx = (j - 1) * num_elements_pre + (k - 1) * itmp1 | |
| 106 |
2/2✓ Branch 0 taken 3095 times.
✓ Branch 1 taken 196 times.
|
3359 | do i = 1, num_elements_pre |
| 107 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 3095 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3095 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3095 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3095 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3095 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3095 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3095 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3095 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3095 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3095 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3095 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3095 times.
|
3291 | output%val(idx + i, s) = input%val(idx + i, s) + bias%val(j,1) |
| 108 | end do | ||
| 109 | end do | ||
| 110 | end do | ||
| 111 | end do | ||
| 112 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
|
43 | output%indices(2) = 1 |
| 113 | else | ||
| 114 | ✗ | call stop_program("add_bias: dim_act_on_shape=.false. not implemented yet") | |
| 115 | ✗ | output%indices(2) = 0 | |
| 116 | end if | ||
| 117 | |||
| 118 | 43 | output%get_partial_left => get_partial_add | |
| 119 | 43 | output%get_partial_right => get_partial_add_bias | |
| 120 | 43 | output%get_partial_left_val => get_partial_add_val | |
| 121 | 43 | output%get_partial_right_val => get_partial_add_bias_val | |
| 122 |
1/2✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
|
43 | if(input%requires_grad .or. bias%requires_grad)then |
| 123 | 43 | output%requires_grad = .true. | |
| 124 | 43 | output%is_forward = input%is_forward .or. bias%is_forward | |
| 125 | 43 | output%operation = 'add_bias' | |
| 126 | 43 | output%left_operand => input | |
| 127 | 43 | output%right_operand => bias | |
| 128 | end if | ||
| 129 | |||
| 130 | 43 | end function add_bias | |
| 131 | !------------------------------------------------------------------------------- | ||
| 132 | ✗ | function get_partial_add(this, upstream_grad) result(output) | |
| 133 | !! Get partial derivative with respect to left operand | ||
| 134 | implicit none | ||
| 135 | class(array_type), intent(inout) :: this | ||
| 136 | type(array_type), intent(in) :: upstream_grad | ||
| 137 | type(array_type) :: output | ||
| 138 | |||
| 139 | ✗ | output = upstream_grad | |
| 140 | ✗ | end function get_partial_add | |
| 141 | !------------------------------------------------------------------------------- | ||
| 142 |
2/4✓ Branch 0 taken 29 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 29 times.
✗ Branch 3 not taken.
|
29 | pure subroutine get_partial_add_val(this, upstream_grad, output) |
| 143 | !! Get partial derivative with respect to left operand | ||
| 144 | implicit none | ||
| 145 | class(array_type), intent(in) :: this | ||
| 146 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 147 | real(real32), dimension(:,:), intent(out) :: output | ||
| 148 | |||
| 149 |
13/26✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
|
29 | if(size(upstream_grad,2).ne.size(output,2))then |
| 150 | ✗ | if(size(output,1).eq.1)then | |
| 151 | ✗ | output(1,1) = sum(upstream_grad) | |
| 152 | else | ||
| 153 | ✗ | output(:,1) = sum(upstream_grad, dim=2) | |
| 154 | end if | ||
| 155 | else | ||
| 156 |
19/38✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 29 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 29 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 29 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 29 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 29 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 29 times.
|
29 | if(size(output,1).eq.1.and.size(output,1).ne.size(upstream_grad,1))then |
| 157 | ✗ | output(1,:) = sum(upstream_grad,1) | |
| 158 | else | ||
| 159 |
18/32✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 29 times.
✓ Branch 42 taken 29 times.
✓ Branch 43 taken 29 times.
✓ Branch 44 taken 1372 times.
✓ Branch 45 taken 29 times.
|
1430 | output = upstream_grad |
| 160 | end if | ||
| 161 | end if | ||
| 162 | 29 | end subroutine get_partial_add_val | |
| 163 | !------------------------------------------------------------------------------- | ||
| 164 | ✗ | function get_partial_add_bias(this, upstream_grad) result(output) | |
| 165 | !! Get partial derivative with respect to bias operand | ||
| 166 | implicit none | ||
| 167 | class(array_type), intent(inout) :: this | ||
| 168 | type(array_type), intent(in) :: upstream_grad | ||
| 169 | type(array_type) :: output | ||
| 170 | |||
| 171 | ✗ | call output%allocate(array_shape = [ this%right_operand%shape, 1 ]) | |
| 172 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 173 | |||
| 174 | ✗ | end function get_partial_add_bias | |
| 175 | !------------------------------------------------------------------------------- | ||
| 176 |
2/4✓ Branch 0 taken 29 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 29 times.
✗ Branch 3 not taken.
|
29 | pure subroutine get_partial_add_bias_val(this, upstream_grad, output) |
| 177 | implicit none | ||
| 178 | |||
| 179 | ! Arguments | ||
| 180 | class(array_type), intent(in) :: this | ||
| 181 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 182 | real(real32), dimension(:,:), intent(out) :: output | ||
| 183 | |||
| 184 | integer :: i, j, k, s, idx, itmp1 | ||
| 185 | integer :: num_elements_pre, num_elements_post, num_dims | ||
| 186 | |||
| 187 | 29 | num_dims = size(this%left_operand%shape) | |
| 188 | 29 | num_elements_pre = 1 | |
| 189 | 29 | num_elements_post = 1 | |
| 190 |
2/2✓ Branch 0 taken 97 times.
✓ Branch 1 taken 29 times.
|
126 | do i = 1, num_dims |
| 191 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 97 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 97 times.
✓ Branch 6 taken 68 times.
✓ Branch 7 taken 29 times.
|
126 | if(i .lt. this%indices(1))then |
| 192 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
|
68 | num_elements_pre = num_elements_pre * this%left_operand%shape(i) |
| 193 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
|
29 | elseif(i .gt. this%indices(1))then |
| 194 | ✗ | num_elements_post = num_elements_post * this%left_operand%shape(i) | |
| 195 | end if | ||
| 196 | end do | ||
| 197 | |||
| 198 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
|
29 | itmp1 = num_elements_pre * this%left_operand%shape(this%indices(1)) |
| 199 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✓ Branch 18 taken 29 times.
✓ Branch 19 taken 29 times.
✓ Branch 20 taken 61 times.
✓ Branch 21 taken 29 times.
|
119 | output = 0._real32 |
| 200 |
8/14✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✓ Branch 18 taken 29 times.
✓ Branch 19 taken 29 times.
|
58 | do s = 1, size(upstream_grad, 2) |
| 201 |
2/2✓ Branch 0 taken 29 times.
✓ Branch 1 taken 29 times.
|
87 | do k = 1, num_elements_post |
| 202 |
4/6✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✓ Branch 6 taken 61 times.
✓ Branch 7 taken 29 times.
|
119 | do j = 1, this%right_operand%shape(1) |
| 203 | 61 | idx = (j - 1) * num_elements_pre + (k - 1) * itmp1 | |
| 204 |
2/2✓ Branch 0 taken 1372 times.
✓ Branch 1 taken 61 times.
|
1462 | do i = 1, num_elements_pre |
| 205 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 1372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1372 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1372 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1372 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1372 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1372 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1372 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1372 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1372 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1372 times.
|
1433 | output(j,1) = output(j,1) + upstream_grad(idx + i, s) |
| 206 | end do | ||
| 207 | end do | ||
| 208 | end do | ||
| 209 | end do | ||
| 210 | |||
| 211 | 29 | end subroutine get_partial_add_bias_val | |
| 212 | !############################################################################### | ||
| 213 | |||
| 214 | |||
| 215 | !############################################################################### | ||
| 216 | 5 | module function piecewise_array(input, gradient, limit) result(output) | |
| 217 | !! Apply piecewise activation function to input array | ||
| 218 | implicit none | ||
| 219 | |||
| 220 | ! Arguments | ||
| 221 | class(array_type), intent(in), target :: input | ||
| 222 | real(real32), intent(in) :: gradient | ||
| 223 | real(real32), intent(in) :: limit | ||
| 224 | type(array_type), pointer :: output | ||
| 225 | type(array_type), pointer :: b_array | ||
| 226 | |||
| 227 | 5 | output => input%create_result() | |
| 228 |
63/86✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
✓ Branch 54 taken 5 times.
✓ Branch 55 taken 5 times.
✓ Branch 56 taken 8 times.
✓ Branch 57 taken 5 times.
✓ Branch 58 taken 5 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 5 times.
✓ Branch 61 taken 5 times.
✓ Branch 62 taken 8 times.
✓ Branch 63 taken 5 times.
✓ Branch 64 taken 3 times.
✓ Branch 65 taken 5 times.
✓ Branch 66 taken 5 times.
✓ Branch 67 taken 5 times.
✓ Branch 68 taken 8 times.
✓ Branch 69 taken 5 times.
✓ Branch 70 taken 3 times.
✓ Branch 71 taken 5 times.
✓ Branch 72 taken 5 times.
✓ Branch 73 taken 5 times.
✓ Branch 74 taken 8 times.
✓ Branch 75 taken 5 times.
✓ Branch 76 taken 5 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 5 times.
✓ Branch 79 taken 5 times.
✓ Branch 80 taken 8 times.
✓ Branch 81 taken 5 times.
✗ Branch 82 not taken.
✓ Branch 83 taken 8 times.
✓ Branch 84 taken 5 times.
✓ Branch 85 taken 5 times.
✓ Branch 86 taken 8 times.
✓ Branch 87 taken 5 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 8 times.
✓ Branch 90 taken 5 times.
✗ Branch 91 not taken.
✓ Branch 92 taken 5 times.
✓ Branch 93 taken 5 times.
✓ Branch 94 taken 8 times.
✓ Branch 95 taken 5 times.
✓ Branch 96 taken 5 times.
✓ Branch 97 taken 3 times.
✓ Branch 98 taken 5 times.
✓ Branch 99 taken 5 times.
✓ Branch 100 taken 8 times.
✓ Branch 101 taken 5 times.
✓ Branch 102 taken 5 times.
✓ Branch 103 taken 3 times.
|
129 | where(input%val.ge.limit) |
| 229 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
|
5 | output%val = gradient * (input%val - limit) + limit |
| 230 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
|
5 | elsewhere(input%val.le.-limit) |
| 231 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
|
5 | output%val = gradient * (input%val + limit) - limit |
| 232 | elsewhere | ||
| 233 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
|
5 | output%val = input%val |
| 234 | end where | ||
| 235 | |||
| 236 | 5 | output%get_partial_left => get_partial_piecewise | |
| 237 | 5 | output%get_partial_left_val => get_partial_piecewise_val | |
| 238 |
1/2✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
|
5 | if(input%requires_grad)then |
| 239 | 5 | output%requires_grad = .true. | |
| 240 | 5 | output%is_forward = input%is_forward | |
| 241 | 5 | output%operation = 'piecewise' | |
| 242 | 5 | output%left_operand => input | |
| 243 | 5 | output%owns_left_operand = input%is_temporary | |
| 244 | end if | ||
| 245 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
|
5 | allocate(b_array) |
| 246 | 5 | b_array%is_sample_dependent = .false. | |
| 247 | 5 | b_array%requires_grad = .false. | |
| 248 | 5 | call b_array%allocate(array_shape=[2, 1]) | |
| 249 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
|
5 | b_array%val(1,1) = gradient |
| 250 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
|
5 | b_array%val(2,1) = limit |
| 251 | 5 | output%right_operand => b_array | |
| 252 | 5 | output%owns_right_operand = .true. | |
| 253 | |||
| 254 | 5 | end function piecewise_array | |
| 255 | !------------------------------------------------------------------------------- | ||
| 256 | ✗ | function get_partial_piecewise(this, upstream_grad) result(output) | |
| 257 | !! Get partial derivative of piecewise activation | ||
| 258 | implicit none | ||
| 259 | class(array_type), intent(inout) :: this | ||
| 260 | type(array_type), intent(in) :: upstream_grad | ||
| 261 | type(array_type) :: output | ||
| 262 | |||
| 263 | type(array_type), pointer :: ptr | ||
| 264 | |||
| 265 | ptr => merge( & | ||
| 266 | upstream_grad, & | ||
| 267 | ✗ | upstream_grad * this%right_operand%val(1,1), & | |
| 268 | ✗ | this%left_operand%val.le.-this%right_operand%val(2,1) .or. & | |
| 269 | ✗ | this%left_operand%val.ge.this%right_operand%val(2,1) & | |
| 270 | ✗ | ) | |
| 271 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 272 | |||
| 273 | ✗ | end function get_partial_piecewise | |
| 274 | !------------------------------------------------------------------------------- | ||
| 275 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | pure subroutine get_partial_piecewise_val(this, upstream_grad, output) |
| 276 | !! Get partial derivative of piecewise activation (in-place version) | ||
| 277 | implicit none | ||
| 278 | class(array_type), intent(in) :: this | ||
| 279 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 280 | real(real32), dimension(:,:), intent(out) :: output | ||
| 281 | |||
| 282 | 96 | where(this%left_operand%val.le.this%right_operand%val(2,1) .or. & | |
| 283 | 16 | this%left_operand%val.ge.-this%right_operand%val(2,1) & | |
| 284 |
33/62✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 4 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 4 times.
✗ Branch 75 not taken.
✓ Branch 76 taken 4 times.
✗ Branch 78 not taken.
✓ Branch 79 taken 4 times.
✗ Branch 81 not taken.
✓ Branch 82 taken 4 times.
✓ Branch 84 taken 4 times.
✓ Branch 85 taken 4 times.
✓ Branch 86 taken 4 times.
✓ Branch 87 taken 4 times.
✓ Branch 88 taken 4 times.
✗ Branch 89 not taken.
|
16 | ) |
| 285 |
16/32✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
|
4 | output = upstream_grad |
| 286 | elsewhere | ||
| 287 |
18/36✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
|
8 | output = upstream_grad * this%right_operand%val(1,1) |
| 288 | end where | ||
| 289 | |||
| 290 | 4 | end subroutine get_partial_piecewise_val | |
| 291 | !############################################################################### | ||
| 292 | |||
| 293 | |||
| 294 | !############################################################################### | ||
| 295 | 15 | module function softmax_array(input, dim) result(output) | |
| 296 | implicit none | ||
| 297 | class(array_type), intent(in), target :: input | ||
| 298 | integer, intent(in) :: dim | ||
| 299 | type(array_type), pointer :: output | ||
| 300 | |||
| 301 | integer :: i | ||
| 302 | |||
| 303 | 15 | output => input%create_result() | |
| 304 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 13 times.
|
15 | if(dim.eq.1)then |
| 305 |
2/2✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
|
8 | do i = 1, size(input%val, 1) |
| 306 |
29/54✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 6 times.
✓ Branch 45 taken 6 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 6 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 3 times.
✓ Branch 52 taken 6 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 3 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 6 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 6 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 6 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 6 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 6 times.
✓ Branch 69 taken 9 times.
✓ Branch 70 taken 6 times.
✓ Branch 71 taken 9 times.
✓ Branch 72 taken 6 times.
|
33 | output%val(i, :) = exp(input%val(i, :) - maxval(input%val(i,:))) |
| 307 |
23/42✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 6 times.
✓ Branch 45 taken 9 times.
✓ Branch 46 taken 6 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 6 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 6 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 6 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 6 times.
✓ Branch 59 taken 9 times.
✓ Branch 60 taken 6 times.
|
26 | output%val(i, :) = output%val(i, :) / sum(output%val(i, :)) |
| 308 | end do | ||
| 309 |
1/2✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
|
13 | elseif(dim.eq.2)then |
| 310 |
2/2✓ Branch 0 taken 31 times.
✓ Branch 1 taken 13 times.
|
44 | do i = 1, size(input%val, 2) |
| 311 |
30/54✗ Branch 0 not taken.
✓ Branch 1 taken 31 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 31 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 31 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 31 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 31 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 31 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 31 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 31 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 31 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 31 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 31 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 31 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 31 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 31 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 31 times.
✓ Branch 45 taken 31 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 31 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 55 times.
✓ Branch 52 taken 31 times.
✓ Branch 53 taken 23 times.
✓ Branch 54 taken 32 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 31 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 31 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 31 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 31 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 31 times.
✓ Branch 69 taken 86 times.
✓ Branch 70 taken 31 times.
✓ Branch 71 taken 86 times.
✓ Branch 72 taken 31 times.
|
289 | output%val(:, i) = exp(input%val(:, i) - maxval(input%val(:, i))) |
| 312 |
23/42✗ Branch 0 not taken.
✓ Branch 1 taken 31 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 31 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 31 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 31 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 31 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 31 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 31 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 31 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 31 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 31 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 31 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 31 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 31 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 31 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 31 times.
✓ Branch 45 taken 86 times.
✓ Branch 46 taken 31 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 31 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 31 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 31 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 31 times.
✓ Branch 59 taken 86 times.
✓ Branch 60 taken 31 times.
|
216 | output%val(:, i) = output%val(:, i) / sum(output%val(:, i)) |
| 313 | end do | ||
| 314 | else | ||
| 315 | ✗ | call stop_program("softmax_array: Unsupported dimension") | |
| 316 | end if | ||
| 317 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 15 times.
|
15 | allocate(output%indices(1)) |
| 318 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
|
15 | output%indices(1) = dim |
| 319 | |||
| 320 | 15 | output%get_partial_left => get_partial_softmax | |
| 321 | 15 | output%get_partial_left_val => get_partial_softmax_val | |
| 322 | 15 | output%get_partial_left_val_sum => get_partial_softmax_val_sum | |
| 323 |
1/2✓ Branch 0 taken 15 times.
✗ Branch 1 not taken.
|
15 | if(input%requires_grad)then |
| 324 | 15 | output%requires_grad = .true. | |
| 325 | 15 | output%is_forward = input%is_forward | |
| 326 | 15 | output%operation = 'softmax' | |
| 327 | 15 | output%left_operand => input | |
| 328 | 15 | output%owns_left_operand = input%is_temporary | |
| 329 | end if | ||
| 330 | |||
| 331 | 15 | end function softmax_array | |
| 332 | !------------------------------------------------------------------------------- | ||
| 333 | ✗ | function get_partial_softmax(this, upstream_grad) result(output) | |
| 334 | !! Get partial derivative of softmax activation | ||
| 335 | implicit none | ||
| 336 | class(array_type), intent(inout) :: this | ||
| 337 | type(array_type), intent(in) :: upstream_grad | ||
| 338 | type(array_type) :: output | ||
| 339 | type(array_type), pointer :: ptr | ||
| 340 | |||
| 341 | integer :: dim | ||
| 342 | |||
| 343 | ✗ | if(this%indices(1).eq.1)then | |
| 344 | ✗ | dim = 2 | |
| 345 | else | ||
| 346 | ✗ | dim = 1 | |
| 347 | end if | ||
| 348 | ! ptr => this * upstream_grad | ||
| 349 | ! ptr => ptr - this * sum(ptr, dim=dim) | ||
| 350 | ✗ | ptr => softmax_reverse_array(this, upstream_grad, this%indices(1)) | |
| 351 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 352 | |||
| 353 | ✗ | end function get_partial_softmax | |
| 354 | !------------------------------------------------------------------------------- | ||
| 355 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | pure subroutine get_partial_softmax_val(this, upstream_grad, output) |
| 356 | !! Get partial derivative of softmax activation (in-place version) | ||
| 357 | implicit none | ||
| 358 | class(array_type), intent(in) :: this | ||
| 359 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 360 | real(real32), dimension(:,:), intent(out) :: output | ||
| 361 | |||
| 362 | integer :: s, dim | ||
| 363 | |||
| 364 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
|
4 | if(this%indices(1).eq.1)then |
| 365 | ✗ | dim = 2 | |
| 366 | else | ||
| 367 | 4 | dim = 1 | |
| 368 | end if | ||
| 369 |
28/52✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 4 times.
✓ Branch 72 taken 4 times.
✓ Branch 73 taken 4 times.
✓ Branch 74 taken 4 times.
✓ Branch 75 taken 4 times.
|
12 | output = this%val * upstream_grad |
| 370 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if(dim.eq.1)then |
| 371 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
|
8 | do s = 1, size(this%val, 2) |
| 372 |
27/50✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✓ Branch 51 taken 4 times.
✓ Branch 52 taken 4 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 4 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 4 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 4 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 4 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 4 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 4 times.
✓ Branch 71 taken 4 times.
✓ Branch 72 taken 4 times.
|
16 | output(:, s) = output(:, s) - this%val(:, s) * sum(output(:, s)) |
| 373 | end do | ||
| 374 | ✗ | elseif(dim.eq.2)then | |
| 375 | ✗ | do s = 1, size(this%val, 1) | |
| 376 | ✗ | output(s, :) = output(s, :) - this%val(s, :) * sum(output(s, :)) | |
| 377 | end do | ||
| 378 | end if | ||
| 379 | 4 | end subroutine get_partial_softmax_val | |
| 380 | !------------------------------------------------------------------------------- | ||
| 381 | ✗ | pure subroutine get_partial_softmax_val_sum(this, upstream_grad, output) | |
| 382 | !! Get partial derivative of softmax activation (in-place version, summed over samples) | ||
| 383 | implicit none | ||
| 384 | class(array_type), intent(in) :: this | ||
| 385 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 386 | real(real32), dimension(:), intent(out) :: output | ||
| 387 | |||
| 388 | integer :: s, i, nfeat, nsamp | ||
| 389 | real(real32) :: dot | ||
| 390 | |||
| 391 | ✗ | output = 0.0_real32 | |
| 392 | ✗ | if(this%indices(1) .eq. 1)then | |
| 393 | ✗ | nsamp = size(this%val,1) | |
| 394 | ✗ | nfeat = size(this%val,2) | |
| 395 | ✗ | do s = 1, nsamp | |
| 396 | ! compute g·y | ||
| 397 | ✗ | dot = 0.0_real32 | |
| 398 | ✗ | do i = 1, nfeat | |
| 399 | ✗ | dot = dot + upstream_grad(s,i) * this%val(s,i) | |
| 400 | end do | ||
| 401 | ! accumulate reduced gradient | ||
| 402 | ✗ | do concurrent( i = 1 : nfeat ) | |
| 403 | ✗ | output(i) = output(i) + this%val(s,i) * (upstream_grad(s,i) - dot) | |
| 404 | end do | ||
| 405 | end do | ||
| 406 | else | ||
| 407 | ✗ | nsamp = size(this%val,2) | |
| 408 | ✗ | nfeat = size(this%val,1) | |
| 409 | ✗ | do s = 1, nsamp | |
| 410 | ✗ | dot = 0.0_real32 | |
| 411 | ✗ | do i = 1, nfeat | |
| 412 | ✗ | dot = dot + upstream_grad(i,s) * this%val(i,s) | |
| 413 | end do | ||
| 414 | ✗ | do concurrent( i = 1 : nfeat ) | |
| 415 | ✗ | output(i) = output(i) + this%val(i,s) * (upstream_grad(i,s) - dot) | |
| 416 | end do | ||
| 417 | end do | ||
| 418 | end if | ||
| 419 | ✗ | end subroutine get_partial_softmax_val_sum | |
| 420 | !############################################################################### | ||
| 421 | |||
| 422 | |||
| 423 | !############################################################################### | ||
| 424 | 10 | module function swish_array(input, beta) result(output) | |
| 425 | !! Swish activation function | ||
| 426 | implicit none | ||
| 427 | |||
| 428 | ! Arguments | ||
| 429 | class(array_type), intent(in), target :: input | ||
| 430 | real(real32), intent(in) :: beta | ||
| 431 | type(array_type), pointer :: output | ||
| 432 | type(array_type), pointer :: b_array | ||
| 433 | |||
| 434 | 10 | output => input%create_result() | |
| 435 |
30/54✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 10 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 10 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 10 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 10 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 10 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 10 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 10 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 10 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 10 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 10 times.
✓ Branch 54 taken 10 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 23 times.
✓ Branch 57 taken 10 times.
✓ Branch 58 taken 104 times.
✓ Branch 59 taken 23 times.
✓ Branch 60 taken 10 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 10 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 10 times.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✓ Branch 68 taken 23 times.
✓ Branch 69 taken 10 times.
✓ Branch 70 taken 104 times.
✓ Branch 71 taken 23 times.
|
264 | output%val = input%val * (1._real32 / (1._real32 + exp(-beta * input%val))) |
| 436 | |||
| 437 | 10 | output%get_partial_left => get_partial_swish | |
| 438 | 10 | output%get_partial_left_val => get_partial_swish_val | |
| 439 |
1/2✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
|
10 | if(input%requires_grad)then |
| 440 | 10 | output%requires_grad = .true. | |
| 441 | 10 | output%is_forward = input%is_forward | |
| 442 | 10 | output%operation = 'swish' | |
| 443 | 10 | output%left_operand => input | |
| 444 | 10 | output%owns_left_operand = input%is_temporary | |
| 445 | end if | ||
| 446 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
|
10 | allocate(b_array) |
| 447 | 10 | b_array%is_sample_dependent = .false. | |
| 448 | 10 | b_array%is_scalar = .true. | |
| 449 | 10 | b_array%requires_grad = .false. | |
| 450 | 10 | call b_array%allocate(array_shape=[1, 1]) | |
| 451 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
|
10 | b_array%val(1,1) = beta |
| 452 | 10 | output%right_operand => b_array | |
| 453 | 10 | output%owns_right_operand = .true. | |
| 454 | |||
| 455 | 10 | end function swish_array | |
| 456 | !------------------------------------------------------------------------------- | ||
| 457 | ✗ | function get_partial_swish(this, upstream_grad) result(output) | |
| 458 | !! Get partial derivative of swish activation | ||
| 459 | implicit none | ||
| 460 | class(array_type), intent(inout) :: this | ||
| 461 | type(array_type), intent(in) :: upstream_grad | ||
| 462 | type(array_type) :: output | ||
| 463 | |||
| 464 | type(array_type), pointer :: ptr | ||
| 465 | type(array_type), pointer :: exp_term | ||
| 466 | |||
| 467 | ✗ | exp_term => exp(this%right_operand%val(1,1) * this%left_operand) | |
| 468 | |||
| 469 | ptr => upstream_grad * exp_term * ( & | ||
| 470 | ✗ | this%right_operand%val(1,1) * this%left_operand + & | |
| 471 | exp_term + 1._real32 & | ||
| 472 | ✗ | ) / ( ( exp_term + 1._real32 )**2._real32 ) | |
| 473 | |||
| 474 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 475 | ✗ | end function get_partial_swish | |
| 476 | !------------------------------------------------------------------------------- | ||
| 477 |
2/4✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
|
4 | pure subroutine get_partial_swish_val(this, upstream_grad, output) |
| 478 | !! Get partial derivative of swish activation (in-place version) | ||
| 479 | implicit none | ||
| 480 | class(array_type), intent(in) :: this | ||
| 481 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 482 | real(real32), dimension(:,:), intent(out) :: output | ||
| 483 | |||
| 484 | 8 | real(real32), dimension(size(this%val,1), size(this%val,2)) :: exp_term | |
| 485 | |||
| 486 |
24/44✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✓ Branch 60 taken 4 times.
✓ Branch 61 taken 4 times.
✓ Branch 62 taken 4 times.
✓ Branch 63 taken 4 times.
|
12 | exp_term = exp(this%right_operand%val(1,1) * this%left_operand%val) |
| 487 | ✗ | output = upstream_grad * exp_term * ( & | |
| 488 | 56 | this%right_operand%val(1,1) * this%left_operand%val + & | |
| 489 | ✗ | exp_term + 1._real32 & | |
| 490 |
56/108✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 4 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 4 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 4 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 4 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 4 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 4 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 4 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 4 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 4 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 4 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 4 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 4 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 4 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 4 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 4 times.
✗ Branch 70 not taken.
✓ Branch 71 taken 4 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 4 times.
✗ Branch 74 not taken.
✓ Branch 75 taken 4 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 4 times.
✗ Branch 80 not taken.
✓ Branch 81 taken 4 times.
✗ Branch 83 not taken.
✓ Branch 84 taken 4 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 4 times.
✗ Branch 89 not taken.
✓ Branch 90 taken 4 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 4 times.
✗ Branch 95 not taken.
✓ Branch 96 taken 4 times.
✗ Branch 98 not taken.
✓ Branch 99 taken 4 times.
✗ Branch 100 not taken.
✓ Branch 101 taken 4 times.
✗ Branch 102 not taken.
✓ Branch 103 taken 4 times.
✗ Branch 104 not taken.
✓ Branch 105 taken 4 times.
✗ Branch 106 not taken.
✓ Branch 107 taken 4 times.
✗ Branch 108 not taken.
✓ Branch 109 taken 4 times.
✗ Branch 110 not taken.
✓ Branch 111 taken 4 times.
✗ Branch 112 not taken.
✓ Branch 113 taken 4 times.
✗ Branch 114 not taken.
✓ Branch 115 taken 4 times.
✗ Branch 117 not taken.
✓ Branch 118 taken 4 times.
✗ Branch 120 not taken.
✓ Branch 121 taken 4 times.
✗ Branch 123 not taken.
✓ Branch 124 taken 4 times.
✓ Branch 126 taken 4 times.
✓ Branch 127 taken 4 times.
✓ Branch 128 taken 4 times.
✓ Branch 129 taken 4 times.
|
12 | ) / ( ( exp_term + 1._real32 )**2._real32 ) |
| 491 | |||
| 492 | 4 | end subroutine get_partial_swish_val | |
| 493 | !############################################################################### | ||
| 494 | |||
| 495 | |||
| 496 | !############################################################################### | ||
| 497 | ✗ | function softmax_reverse_array(softmax, gradient, dim) result(output) | |
| 498 | !! Softmax function for reverse mode autodiff | ||
| 499 | implicit none | ||
| 500 | class(array_type), intent(in), target :: softmax | ||
| 501 | class(array_type), intent(in), target :: gradient | ||
| 502 | integer, intent(in) :: dim | ||
| 503 | type(array_type), pointer :: output | ||
| 504 | |||
| 505 | integer :: i | ||
| 506 | ✗ | real(real32), dimension(size(softmax%val,1), size(softmax%val,2)) :: temp_val | |
| 507 | |||
| 508 | |||
| 509 | ✗ | output => softmax%create_result() | |
| 510 | ✗ | temp_val = gradient%val * softmax%val | |
| 511 | ✗ | if(dim.eq.1)then | |
| 512 | ✗ | do concurrent(i=1:size(softmax%val,1)) | |
| 513 | ✗ | temp_val(i, :) = temp_val(i, :) - softmax%val(i, :) * sum(temp_val(i, :)) | |
| 514 | end do | ||
| 515 | ✗ | elseif(dim.eq.2)then | |
| 516 | ✗ | do concurrent(i=1:size(softmax%val,2)) | |
| 517 | ✗ | temp_val(:, i) = temp_val(:, i) - softmax%val(:, i) * sum(temp_val(:, i)) | |
| 518 | end do | ||
| 519 | else | ||
| 520 | ✗ | call stop_program("softmax_reverse_array: Unsupported dimension") | |
| 521 | end if | ||
| 522 | ✗ | output%val = temp_val | |
| 523 | ✗ | output%indices = [dim] | |
| 524 | |||
| 525 | ✗ | output%get_partial_left => get_partial_softmax_reverse_left | |
| 526 | ✗ | output%get_partial_left_val => get_partial_softmax_reverse_left_val | |
| 527 | ✗ | output%get_partial_right => get_partial_softmax_reverse_right | |
| 528 | ✗ | output%get_partial_right_val => get_partial_softmax_reverse_right_val | |
| 529 | ✗ | if(softmax%requires_grad .or. gradient%requires_grad)then | |
| 530 | ✗ | output%requires_grad = .true. | |
| 531 | ✗ | output%is_forward = softmax%is_forward .or. gradient%is_forward | |
| 532 | ✗ | output%operation = 'softmax_reverse' | |
| 533 | ✗ | output%left_operand => softmax | |
| 534 | ✗ | output%right_operand => gradient | |
| 535 | ✗ | output%owns_left_operand = softmax%is_temporary | |
| 536 | ✗ | output%owns_right_operand = gradient%is_temporary | |
| 537 | end if | ||
| 538 | |||
| 539 | ✗ | end function softmax_reverse_array | |
| 540 | !------------------------------------------------------------------------------- | ||
| 541 | ✗ | function get_partial_softmax_reverse_left(this, upstream_grad) result(output) | |
| 542 | !! Get partial derivative of softmax reverse operation | ||
| 543 | implicit none | ||
| 544 | class(array_type), intent(inout) :: this | ||
| 545 | type(array_type), intent(in) :: upstream_grad | ||
| 546 | type(array_type) :: output | ||
| 547 | |||
| 548 | type(array_type), pointer :: sum_yg, sum_yu | ||
| 549 | type(array_type), pointer :: ptr | ||
| 550 | |||
| 551 | ✗ | sum_yg => sum(this%left_operand * this%right_operand, dim=this%indices(1)) | |
| 552 | ✗ | sum_yu => sum(this%left_operand * upstream_grad, dim=this%indices(1)) | |
| 553 | |||
| 554 | ✗ | ptr => upstream_grad * (this%right_operand - sum_yg) - this%right_operand * sum_yu | |
| 555 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 556 | |||
| 557 | ✗ | end function get_partial_softmax_reverse_left | |
| 558 | !------------------------------------------------------------------------------- | ||
| 559 | ✗ | function get_partial_softmax_reverse_right(this, upstream_grad) result(output) | |
| 560 | !! Get partial derivative of softmax reverse operation | ||
| 561 | implicit none | ||
| 562 | class(array_type), intent(inout) :: this | ||
| 563 | type(array_type), intent(in) :: upstream_grad | ||
| 564 | type(array_type) :: output | ||
| 565 | |||
| 566 | type(array_type), pointer :: ptr | ||
| 567 | |||
| 568 | ptr => ( & | ||
| 569 | upstream_grad - & | ||
| 570 | ✗ | sum(this%left_operand * upstream_grad, dim=this%indices(1)) & | |
| 571 | ✗ | ) * this%left_operand | |
| 572 | ✗ | call output%assign_and_deallocate_source(ptr) | |
| 573 | |||
| 574 | ✗ | end function get_partial_softmax_reverse_right | |
| 575 | !------------------------------------------------------------------------------- | ||
| 576 | ✗ | pure subroutine get_partial_softmax_reverse_left_val(this, upstream_grad, output) | |
| 577 | !! Get partial derivative of softmax reverse operation (in-place version) | ||
| 578 | implicit none | ||
| 579 | class(array_type), intent(in) :: this | ||
| 580 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 581 | real(real32), dimension(:,:), intent(out) :: output | ||
| 582 | |||
| 583 | integer :: dim, i | ||
| 584 | ✗ | real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yg | |
| 585 | ✗ | real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yu | |
| 586 | |||
| 587 | ✗ | dim = this%indices(1) | |
| 588 | ✗ | sum_yg = sum(this%left_operand%val * this%right_operand%val, dim=dim) | |
| 589 | ✗ | sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim) | |
| 590 | |||
| 591 | ✗ | if(dim.eq.1)then | |
| 592 | ✗ | do concurrent(i=1:size(this%val,2)) | |
| 593 | ✗ | output(:, i) = & | |
| 594 | ✗ | upstream_grad(:, i) * (this%right_operand%val(:, i) - sum_yg(i)) - & | |
| 595 | ✗ | this%right_operand%val(:, i) * sum_yu(i) | |
| 596 | end do | ||
| 597 | else | ||
| 598 | ✗ | do concurrent(i=1:size(this%val,1)) | |
| 599 | ✗ | output(i, :) = & | |
| 600 | ✗ | upstream_grad(i, :) * (this%right_operand%val(i, :) - sum_yg(i)) - & | |
| 601 | ✗ | this%right_operand%val(i, :) * sum_yu(i) | |
| 602 | end do | ||
| 603 | end if | ||
| 604 | |||
| 605 | ✗ | end subroutine get_partial_softmax_reverse_left_val | |
| 606 | !------------------------------------------------------------------------------- | ||
| 607 | ✗ | pure subroutine get_partial_softmax_reverse_right_val(this, upstream_grad, output) | |
| 608 | !! Get partial derivative of softmax reverse operation (in-place version) | ||
| 609 | implicit none | ||
| 610 | class(array_type), intent(in) :: this | ||
| 611 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 612 | real(real32), dimension(:,:), intent(out) :: output | ||
| 613 | |||
| 614 | integer :: dim, i | ||
| 615 | ✗ | real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yu | |
| 616 | |||
| 617 | ✗ | dim = this%indices(1) | |
| 618 | ✗ | if(dim.eq.1)then | |
| 619 | ✗ | sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim) | |
| 620 | ✗ | do concurrent(i=1:size(this%val,1)) | |
| 621 | ✗ | output(i, :) = upstream_grad(i, :) - sum_yu(i) * this%left_operand%val(i, :) | |
| 622 | end do | ||
| 623 | else | ||
| 624 | ✗ | sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim) | |
| 625 | ✗ | do concurrent(i=1:size(this%val,2)) | |
| 626 | ✗ | output(:, i) = upstream_grad(:, i) - sum_yu(i) * this%left_operand%val(:, i) | |
| 627 | end do | ||
| 628 | end if | ||
| 629 | ✗ | end subroutine get_partial_softmax_reverse_right_val | |
| 630 | !############################################################################### | ||
| 631 | |||
| 632 | end submodule athena__diffstruc_extd_submodule | ||
| 633 |