GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_diffstruc_extd_sub.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 154 285 54.0%
Functions: 0 0 -%
Branches: 691 2322 29.8%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule
2 !! Submodule containing implementations for extended diffstruc array operations
3 use coreutils, only: stop_program
4 use diffstruc, only: &
5 operator(+), operator(-), operator(*), concat, exp, sum, merge
6
7 contains
8
9 !###############################################################################
10
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 module function add_array_ptr(a, idx1, idx2) result(c)
11 !! Add two autodiff arrays
12 implicit none
13
14 ! Arguments
15 type(array_ptr_type), dimension(:), intent(in) :: a
16 integer, intent(in) :: idx1, idx2
17 type(array_type), pointer :: c
18
19 ! Local variables
20 integer :: i
21
22
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
1 c => a(1)%array(idx1, idx2) + a(2)%array(idx1, idx2)
23
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 do i = 3, size(a), 1
24
0/12
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
1 c => c + a(i)%array(idx1, idx2)
25 end do
26 1 end function add_array_ptr
27 !###############################################################################
28
29
30 !###############################################################################
31
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 module function concat_array_ptr(a, idx1, idx2, dim) result(c)
32 !! Concatenate two autodiff arrays along a specified dimension
33 implicit none
34
35 ! Arguments
36 type(array_ptr_type), dimension(:), intent(in) :: a
37 integer, intent(in) :: idx1, idx2, dim
38 type(array_type), pointer :: c
39
40 ! Local variables
41 integer :: i
42
43
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
1 c => concat(a(1)%array(idx1, idx2), a(2)%array(idx1, idx2), dim)
44
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 do i = 3, size(a), 1
45
0/12
✗ Branch 0 not taken.
✗ Branch 1 not taken.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
1 c => concat(c, a(i)%array(idx1, idx2), dim)
46 end do
47 1 end function concat_array_ptr
48 !###############################################################################
49
50
51 !##############################################################################!
52 ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * !
53 !##############################################################################!
54
55
56 !###############################################################################
57 43 module function add_bias(input, bias, dim, dim_act_on_shape) result(output)
58 !! Add bias to input array along specified dimension
59 implicit none
60
61 ! Arguments
62 class(array_type), intent(in), target :: input
63 class(array_type), intent(in), target :: bias
64 integer, intent(in) :: dim
65 logical, intent(in), optional :: dim_act_on_shape
66 type(array_type), pointer :: output
67
68 ! Local variables
69 integer :: i, j, k, s, idx, itmp1
70 integer :: num_elements_pre, num_elements_post, num_dims
71 logical :: dim_act_on_shape_
72
73
1/2
✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
43 if(present(dim_act_on_shape))then
74 43 dim_act_on_shape_ = dim_act_on_shape
75 else
76 dim_act_on_shape_ = .false.
77 end if
78
79 43 output => input%create_result()
80
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 43 times.
43 allocate(output%indices(2))
81
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
43 output%indices(1) = dim
82
1/2
✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
43 if(dim_act_on_shape_)then
83 43 num_dims = size(input%shape)
84
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
43 if(dim .gt. num_dims)then
85 call stop_program("Dimension for add_bias exceeds input dimensions")
86 return
87
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
43 elseif(size(bias%shape) .ne. 1)then
88 call stop_program("Bias must be a 1D array")
89 return
90 end if
91 43 num_elements_pre = 1
92 43 num_elements_post = 1
93
2/2
✓ Branch 0 taken 128 times.
✓ Branch 1 taken 43 times.
171 do i = 1, num_dims
94
2/2
✓ Branch 0 taken 85 times.
✓ Branch 1 taken 43 times.
171 if(i .lt. dim)then
95
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 85 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 85 times.
85 num_elements_pre = num_elements_pre * input%shape(i)
96
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
43 elseif(i .gt. dim)then
97 num_elements_post = num_elements_post * input%shape(i)
98 end if
99 end do
100
101
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
43 itmp1 = num_elements_pre * input%shape(dim)
102
2/2
✓ Branch 0 taken 68 times.
✓ Branch 1 taken 43 times.
111 do s = 1, size(input%val, 2)
103
2/2
✓ Branch 0 taken 68 times.
✓ Branch 1 taken 68 times.
179 do k = 1, num_elements_post
104
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
✓ Branch 6 taken 196 times.
✓ Branch 7 taken 68 times.
332 do j = 1, bias%shape(1)
105 196 idx = (j - 1) * num_elements_pre + (k - 1) * itmp1
106
2/2
✓ Branch 0 taken 3095 times.
✓ Branch 1 taken 196 times.
3359 do i = 1, num_elements_pre
107
12/24
✗ Branch 0 not taken.
✓ Branch 1 taken 3095 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3095 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3095 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3095 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3095 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3095 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3095 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3095 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 3095 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 3095 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3095 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3095 times.
3291 output%val(idx + i, s) = input%val(idx + i, s) + bias%val(j,1)
108 end do
109 end do
110 end do
111 end do
112
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 43 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 43 times.
43 output%indices(2) = 1
113 else
114 call stop_program("add_bias: dim_act_on_shape=.false. not implemented yet")
115 output%indices(2) = 0
116 end if
117
118 43 output%get_partial_left => get_partial_add
119 43 output%get_partial_right => get_partial_add_bias
120 43 output%get_partial_left_val => get_partial_add_val
121 43 output%get_partial_right_val => get_partial_add_bias_val
122
1/2
✓ Branch 0 taken 43 times.
✗ Branch 1 not taken.
43 if(input%requires_grad .or. bias%requires_grad)then
123 43 output%requires_grad = .true.
124 43 output%is_forward = input%is_forward .or. bias%is_forward
125 43 output%operation = 'add_bias'
126 43 output%left_operand => input
127 43 output%right_operand => bias
128 end if
129
130 43 end function add_bias
131 !-------------------------------------------------------------------------------
132 function get_partial_add(this, upstream_grad) result(output)
133 !! Get partial derivative with respect to left operand
134 implicit none
135 class(array_type), intent(inout) :: this
136 type(array_type), intent(in) :: upstream_grad
137 type(array_type) :: output
138
139 output = upstream_grad
140 end function get_partial_add
141 !-------------------------------------------------------------------------------
142
2/4
✓ Branch 0 taken 29 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 29 times.
✗ Branch 3 not taken.
29 pure subroutine get_partial_add_val(this, upstream_grad, output)
143 !! Get partial derivative with respect to left operand
144 implicit none
145 class(array_type), intent(in) :: this
146 real(real32), dimension(:,:), intent(in) :: upstream_grad
147 real(real32), dimension(:,:), intent(out) :: output
148
149
13/26
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
29 if(size(upstream_grad,2).ne.size(output,2))then
150 if(size(output,1).eq.1)then
151 output(1,1) = sum(upstream_grad)
152 else
153 output(:,1) = sum(upstream_grad, dim=2)
154 end if
155 else
156
19/38
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 29 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 29 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 29 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 29 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 29 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 29 times.
29 if(size(output,1).eq.1.and.size(output,1).ne.size(upstream_grad,1))then
157 output(1,:) = sum(upstream_grad,1)
158 else
159
18/32
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 29 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 29 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 29 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 29 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 29 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 29 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 29 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 29 times.
✓ Branch 42 taken 29 times.
✓ Branch 43 taken 29 times.
✓ Branch 44 taken 1372 times.
✓ Branch 45 taken 29 times.
1430 output = upstream_grad
160 end if
161 end if
162 29 end subroutine get_partial_add_val
163 !-------------------------------------------------------------------------------
164 function get_partial_add_bias(this, upstream_grad) result(output)
165 !! Get partial derivative with respect to bias operand
166 implicit none
167 class(array_type), intent(inout) :: this
168 type(array_type), intent(in) :: upstream_grad
169 type(array_type) :: output
170
171 call output%allocate(array_shape = [ this%right_operand%shape, 1 ])
172 call this%get_partial_right_val(upstream_grad%val, output%val)
173
174 end function get_partial_add_bias
175 !-------------------------------------------------------------------------------
176
2/4
✓ Branch 0 taken 29 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 29 times.
✗ Branch 3 not taken.
29 pure subroutine get_partial_add_bias_val(this, upstream_grad, output)
177 implicit none
178
179 ! Arguments
180 class(array_type), intent(in) :: this
181 real(real32), dimension(:,:), intent(in) :: upstream_grad
182 real(real32), dimension(:,:), intent(out) :: output
183
184 integer :: i, j, k, s, idx, itmp1
185 integer :: num_elements_pre, num_elements_post, num_dims
186
187 29 num_dims = size(this%left_operand%shape)
188 29 num_elements_pre = 1
189 29 num_elements_post = 1
190
2/2
✓ Branch 0 taken 97 times.
✓ Branch 1 taken 29 times.
126 do i = 1, num_dims
191
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 97 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 97 times.
✓ Branch 6 taken 68 times.
✓ Branch 7 taken 29 times.
126 if(i .lt. this%indices(1))then
192
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 68 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 68 times.
68 num_elements_pre = num_elements_pre * this%left_operand%shape(i)
193
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
29 elseif(i .gt. this%indices(1))then
194 num_elements_post = num_elements_post * this%left_operand%shape(i)
195 end if
196 end do
197
198
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
29 itmp1 = num_elements_pre * this%left_operand%shape(this%indices(1))
199
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✓ Branch 18 taken 29 times.
✓ Branch 19 taken 29 times.
✓ Branch 20 taken 61 times.
✓ Branch 21 taken 29 times.
119 output = 0._real32
200
8/14
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 29 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 29 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 29 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 29 times.
✓ Branch 18 taken 29 times.
✓ Branch 19 taken 29 times.
58 do s = 1, size(upstream_grad, 2)
201
2/2
✓ Branch 0 taken 29 times.
✓ Branch 1 taken 29 times.
87 do k = 1, num_elements_post
202
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 29 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 29 times.
✓ Branch 6 taken 61 times.
✓ Branch 7 taken 29 times.
119 do j = 1, this%right_operand%shape(1)
203 61 idx = (j - 1) * num_elements_pre + (k - 1) * itmp1
204
2/2
✓ Branch 0 taken 1372 times.
✓ Branch 1 taken 61 times.
1462 do i = 1, num_elements_pre
205
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 1372 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1372 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1372 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1372 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1372 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1372 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1372 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1372 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1372 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1372 times.
1433 output(j,1) = output(j,1) + upstream_grad(idx + i, s)
206 end do
207 end do
208 end do
209 end do
210
211 29 end subroutine get_partial_add_bias_val
212 !###############################################################################
213
214
215 !###############################################################################
216 5 module function piecewise_array(input, gradient, limit) result(output)
217 !! Apply piecewise activation function to input array
218 implicit none
219
220 ! Arguments
221 class(array_type), intent(in), target :: input
222 real(real32), intent(in) :: gradient
223 real(real32), intent(in) :: limit
224 type(array_type), pointer :: output
225 type(array_type), pointer :: b_array
226
227 5 output => input%create_result()
228
63/86
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
✓ Branch 54 taken 5 times.
✓ Branch 55 taken 5 times.
✓ Branch 56 taken 8 times.
✓ Branch 57 taken 5 times.
✓ Branch 58 taken 5 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 5 times.
✓ Branch 61 taken 5 times.
✓ Branch 62 taken 8 times.
✓ Branch 63 taken 5 times.
✓ Branch 64 taken 3 times.
✓ Branch 65 taken 5 times.
✓ Branch 66 taken 5 times.
✓ Branch 67 taken 5 times.
✓ Branch 68 taken 8 times.
✓ Branch 69 taken 5 times.
✓ Branch 70 taken 3 times.
✓ Branch 71 taken 5 times.
✓ Branch 72 taken 5 times.
✓ Branch 73 taken 5 times.
✓ Branch 74 taken 8 times.
✓ Branch 75 taken 5 times.
✓ Branch 76 taken 5 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 5 times.
✓ Branch 79 taken 5 times.
✓ Branch 80 taken 8 times.
✓ Branch 81 taken 5 times.
✗ Branch 82 not taken.
✓ Branch 83 taken 8 times.
✓ Branch 84 taken 5 times.
✓ Branch 85 taken 5 times.
✓ Branch 86 taken 8 times.
✓ Branch 87 taken 5 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 8 times.
✓ Branch 90 taken 5 times.
✗ Branch 91 not taken.
✓ Branch 92 taken 5 times.
✓ Branch 93 taken 5 times.
✓ Branch 94 taken 8 times.
✓ Branch 95 taken 5 times.
✓ Branch 96 taken 5 times.
✓ Branch 97 taken 3 times.
✓ Branch 98 taken 5 times.
✓ Branch 99 taken 5 times.
✓ Branch 100 taken 8 times.
✓ Branch 101 taken 5 times.
✓ Branch 102 taken 5 times.
✓ Branch 103 taken 3 times.
129 where(input%val.ge.limit)
229
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
5 output%val = gradient * (input%val - limit) + limit
230
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
5 elsewhere(input%val.le.-limit)
231
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
5 output%val = gradient * (input%val + limit) - limit
232 elsewhere
233
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 5 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 5 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 5 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 5 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 5 times.
5 output%val = input%val
234 end where
235
236 5 output%get_partial_left => get_partial_piecewise
237 5 output%get_partial_left_val => get_partial_piecewise_val
238
1/2
✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
5 if(input%requires_grad)then
239 5 output%requires_grad = .true.
240 5 output%is_forward = input%is_forward
241 5 output%operation = 'piecewise'
242 5 output%left_operand => input
243 5 output%owns_left_operand = input%is_temporary
244 end if
245
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
5 allocate(b_array)
246 5 b_array%is_sample_dependent = .false.
247 5 b_array%requires_grad = .false.
248 5 call b_array%allocate(array_shape=[2, 1])
249
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
5 b_array%val(1,1) = gradient
250
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
5 b_array%val(2,1) = limit
251 5 output%right_operand => b_array
252 5 output%owns_right_operand = .true.
253
254 5 end function piecewise_array
255 !-------------------------------------------------------------------------------
256 function get_partial_piecewise(this, upstream_grad) result(output)
257 !! Get partial derivative of piecewise activation
258 implicit none
259 class(array_type), intent(inout) :: this
260 type(array_type), intent(in) :: upstream_grad
261 type(array_type) :: output
262
263 type(array_type), pointer :: ptr
264
265 ptr => merge( &
266 upstream_grad, &
267 upstream_grad * this%right_operand%val(1,1), &
268 this%left_operand%val.le.-this%right_operand%val(2,1) .or. &
269 this%left_operand%val.ge.this%right_operand%val(2,1) &
270 )
271 call output%assign_and_deallocate_source(ptr)
272
273 end function get_partial_piecewise
274 !-------------------------------------------------------------------------------
275
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 pure subroutine get_partial_piecewise_val(this, upstream_grad, output)
276 !! Get partial derivative of piecewise activation (in-place version)
277 implicit none
278 class(array_type), intent(in) :: this
279 real(real32), dimension(:,:), intent(in) :: upstream_grad
280 real(real32), dimension(:,:), intent(out) :: output
281
282 96 where(this%left_operand%val.le.this%right_operand%val(2,1) .or. &
283 16 this%left_operand%val.ge.-this%right_operand%val(2,1) &
284
33/62
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 4 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 4 times.
✗ Branch 75 not taken.
✓ Branch 76 taken 4 times.
✗ Branch 78 not taken.
✓ Branch 79 taken 4 times.
✗ Branch 81 not taken.
✓ Branch 82 taken 4 times.
✓ Branch 84 taken 4 times.
✓ Branch 85 taken 4 times.
✓ Branch 86 taken 4 times.
✓ Branch 87 taken 4 times.
✓ Branch 88 taken 4 times.
✗ Branch 89 not taken.
16 )
285
16/32
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
4 output = upstream_grad
286 elsewhere
287
18/36
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
8 output = upstream_grad * this%right_operand%val(1,1)
288 end where
289
290 4 end subroutine get_partial_piecewise_val
291 !###############################################################################
292
293
294 !###############################################################################
295 15 module function softmax_array(input, dim) result(output)
296 implicit none
297 class(array_type), intent(in), target :: input
298 integer, intent(in) :: dim
299 type(array_type), pointer :: output
300
301 integer :: i
302
303 15 output => input%create_result()
304
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 13 times.
15 if(dim.eq.1)then
305
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 do i = 1, size(input%val, 1)
306
29/54
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 6 times.
✓ Branch 45 taken 6 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 6 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 3 times.
✓ Branch 52 taken 6 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 3 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 6 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 6 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 6 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 6 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 6 times.
✓ Branch 69 taken 9 times.
✓ Branch 70 taken 6 times.
✓ Branch 71 taken 9 times.
✓ Branch 72 taken 6 times.
33 output%val(i, :) = exp(input%val(i, :) - maxval(input%val(i,:)))
307
23/42
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 6 times.
✓ Branch 45 taken 9 times.
✓ Branch 46 taken 6 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 6 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 6 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 6 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 6 times.
✓ Branch 59 taken 9 times.
✓ Branch 60 taken 6 times.
26 output%val(i, :) = output%val(i, :) / sum(output%val(i, :))
308 end do
309
1/2
✓ Branch 0 taken 13 times.
✗ Branch 1 not taken.
13 elseif(dim.eq.2)then
310
2/2
✓ Branch 0 taken 31 times.
✓ Branch 1 taken 13 times.
44 do i = 1, size(input%val, 2)
311
30/54
✗ Branch 0 not taken.
✓ Branch 1 taken 31 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 31 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 31 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 31 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 31 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 31 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 31 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 31 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 31 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 31 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 31 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 31 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 31 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 31 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 31 times.
✓ Branch 45 taken 31 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 31 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 55 times.
✓ Branch 52 taken 31 times.
✓ Branch 53 taken 23 times.
✓ Branch 54 taken 32 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 31 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 31 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 31 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 31 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 31 times.
✓ Branch 69 taken 86 times.
✓ Branch 70 taken 31 times.
✓ Branch 71 taken 86 times.
✓ Branch 72 taken 31 times.
289 output%val(:, i) = exp(input%val(:, i) - maxval(input%val(:, i)))
312
23/42
✗ Branch 0 not taken.
✓ Branch 1 taken 31 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 31 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 31 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 31 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 31 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 31 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 31 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 31 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 31 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 31 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 31 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 31 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 31 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 31 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 31 times.
✓ Branch 45 taken 86 times.
✓ Branch 46 taken 31 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 31 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 31 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 31 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 31 times.
✓ Branch 59 taken 86 times.
✓ Branch 60 taken 31 times.
216 output%val(:, i) = output%val(:, i) / sum(output%val(:, i))
313 end do
314 else
315 call stop_program("softmax_array: Unsupported dimension")
316 end if
317
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 15 times.
15 allocate(output%indices(1))
318
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 15 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 15 times.
15 output%indices(1) = dim
319
320 15 output%get_partial_left => get_partial_softmax
321 15 output%get_partial_left_val => get_partial_softmax_val
322 15 output%get_partial_left_val_sum => get_partial_softmax_val_sum
323
1/2
✓ Branch 0 taken 15 times.
✗ Branch 1 not taken.
15 if(input%requires_grad)then
324 15 output%requires_grad = .true.
325 15 output%is_forward = input%is_forward
326 15 output%operation = 'softmax'
327 15 output%left_operand => input
328 15 output%owns_left_operand = input%is_temporary
329 end if
330
331 15 end function softmax_array
332 !-------------------------------------------------------------------------------
333 function get_partial_softmax(this, upstream_grad) result(output)
334 !! Get partial derivative of softmax activation
335 implicit none
336 class(array_type), intent(inout) :: this
337 type(array_type), intent(in) :: upstream_grad
338 type(array_type) :: output
339 type(array_type), pointer :: ptr
340
341 integer :: dim
342
343 if(this%indices(1).eq.1)then
344 dim = 2
345 else
346 dim = 1
347 end if
348 ! ptr => this * upstream_grad
349 ! ptr => ptr - this * sum(ptr, dim=dim)
350 ptr => softmax_reverse_array(this, upstream_grad, this%indices(1))
351 call output%assign_and_deallocate_source(ptr)
352
353 end function get_partial_softmax
354 !-------------------------------------------------------------------------------
355
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 pure subroutine get_partial_softmax_val(this, upstream_grad, output)
356 !! Get partial derivative of softmax activation (in-place version)
357 implicit none
358 class(array_type), intent(in) :: this
359 real(real32), dimension(:,:), intent(in) :: upstream_grad
360 real(real32), dimension(:,:), intent(out) :: output
361
362 integer :: s, dim
363
364
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
4 if(this%indices(1).eq.1)then
365 dim = 2
366 else
367 4 dim = 1
368 end if
369
28/52
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 4 times.
✓ Branch 72 taken 4 times.
✓ Branch 73 taken 4 times.
✓ Branch 74 taken 4 times.
✓ Branch 75 taken 4 times.
12 output = this%val * upstream_grad
370
1/2
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
4 if(dim.eq.1)then
371
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
8 do s = 1, size(this%val, 2)
372
27/50
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✓ Branch 51 taken 4 times.
✓ Branch 52 taken 4 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 4 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 4 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 4 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 4 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 4 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 4 times.
✓ Branch 71 taken 4 times.
✓ Branch 72 taken 4 times.
16 output(:, s) = output(:, s) - this%val(:, s) * sum(output(:, s))
373 end do
374 elseif(dim.eq.2)then
375 do s = 1, size(this%val, 1)
376 output(s, :) = output(s, :) - this%val(s, :) * sum(output(s, :))
377 end do
378 end if
379 4 end subroutine get_partial_softmax_val
380 !-------------------------------------------------------------------------------
381 pure subroutine get_partial_softmax_val_sum(this, upstream_grad, output)
382 !! Get partial derivative of softmax activation (in-place version, summed over samples)
383 implicit none
384 class(array_type), intent(in) :: this
385 real(real32), dimension(:,:), intent(in) :: upstream_grad
386 real(real32), dimension(:), intent(out) :: output
387
388 integer :: s, i, nfeat, nsamp
389 real(real32) :: dot
390
391 output = 0.0_real32
392 if(this%indices(1) .eq. 1)then
393 nsamp = size(this%val,1)
394 nfeat = size(this%val,2)
395 do s = 1, nsamp
396 ! compute g·y
397 dot = 0.0_real32
398 do i = 1, nfeat
399 dot = dot + upstream_grad(s,i) * this%val(s,i)
400 end do
401 ! accumulate reduced gradient
402 do concurrent( i = 1 : nfeat )
403 output(i) = output(i) + this%val(s,i) * (upstream_grad(s,i) - dot)
404 end do
405 end do
406 else
407 nsamp = size(this%val,2)
408 nfeat = size(this%val,1)
409 do s = 1, nsamp
410 dot = 0.0_real32
411 do i = 1, nfeat
412 dot = dot + upstream_grad(i,s) * this%val(i,s)
413 end do
414 do concurrent( i = 1 : nfeat )
415 output(i) = output(i) + this%val(i,s) * (upstream_grad(i,s) - dot)
416 end do
417 end do
418 end if
419 end subroutine get_partial_softmax_val_sum
420 !###############################################################################
421
422
423 !###############################################################################
424 10 module function swish_array(input, beta) result(output)
425 !! Swish activation function
426 implicit none
427
428 ! Arguments
429 class(array_type), intent(in), target :: input
430 real(real32), intent(in) :: beta
431 type(array_type), pointer :: output
432 type(array_type), pointer :: b_array
433
434 10 output => input%create_result()
435
30/54
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 10 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 10 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 10 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 10 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 10 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 10 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 10 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 10 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 10 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 10 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 10 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 10 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 10 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 10 times.
✓ Branch 54 taken 10 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 23 times.
✓ Branch 57 taken 10 times.
✓ Branch 58 taken 104 times.
✓ Branch 59 taken 23 times.
✓ Branch 60 taken 10 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 10 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 10 times.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✗ Branch 67 not taken.
✓ Branch 68 taken 23 times.
✓ Branch 69 taken 10 times.
✓ Branch 70 taken 104 times.
✓ Branch 71 taken 23 times.
264 output%val = input%val * (1._real32 / (1._real32 + exp(-beta * input%val)))
436
437 10 output%get_partial_left => get_partial_swish
438 10 output%get_partial_left_val => get_partial_swish_val
439
1/2
✓ Branch 0 taken 10 times.
✗ Branch 1 not taken.
10 if(input%requires_grad)then
440 10 output%requires_grad = .true.
441 10 output%is_forward = input%is_forward
442 10 output%operation = 'swish'
443 10 output%left_operand => input
444 10 output%owns_left_operand = input%is_temporary
445 end if
446
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
10 allocate(b_array)
447 10 b_array%is_sample_dependent = .false.
448 10 b_array%is_scalar = .true.
449 10 b_array%requires_grad = .false.
450 10 call b_array%allocate(array_shape=[1, 1])
451
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 10 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 10 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 10 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 10 times.
10 b_array%val(1,1) = beta
452 10 output%right_operand => b_array
453 10 output%owns_right_operand = .true.
454
455 10 end function swish_array
456 !-------------------------------------------------------------------------------
457 function get_partial_swish(this, upstream_grad) result(output)
458 !! Get partial derivative of swish activation
459 implicit none
460 class(array_type), intent(inout) :: this
461 type(array_type), intent(in) :: upstream_grad
462 type(array_type) :: output
463
464 type(array_type), pointer :: ptr
465 type(array_type), pointer :: exp_term
466
467 exp_term => exp(this%right_operand%val(1,1) * this%left_operand)
468
469 ptr => upstream_grad * exp_term * ( &
470 this%right_operand%val(1,1) * this%left_operand + &
471 exp_term + 1._real32 &
472 ) / ( ( exp_term + 1._real32 )**2._real32 )
473
474 call output%assign_and_deallocate_source(ptr)
475 end function get_partial_swish
476 !-------------------------------------------------------------------------------
477
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
4 pure subroutine get_partial_swish_val(this, upstream_grad, output)
478 !! Get partial derivative of swish activation (in-place version)
479 implicit none
480 class(array_type), intent(in) :: this
481 real(real32), dimension(:,:), intent(in) :: upstream_grad
482 real(real32), dimension(:,:), intent(out) :: output
483
484 8 real(real32), dimension(size(this%val,1), size(this%val,2)) :: exp_term
485
486
24/44
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 4 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 4 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 4 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 4 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 4 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 4 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 4 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 4 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 4 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 4 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 4 times.
✓ Branch 60 taken 4 times.
✓ Branch 61 taken 4 times.
✓ Branch 62 taken 4 times.
✓ Branch 63 taken 4 times.
12 exp_term = exp(this%right_operand%val(1,1) * this%left_operand%val)
487 output = upstream_grad * exp_term * ( &
488 56 this%right_operand%val(1,1) * this%left_operand%val + &
489 exp_term + 1._real32 &
490
56/108
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 4 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 4 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 4 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 4 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 4 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 4 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 4 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 4 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 4 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 4 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 4 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 4 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 4 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 4 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 4 times.
✗ Branch 58 not taken.
✓ Branch 59 taken 4 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 4 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 4 times.
✗ Branch 64 not taken.
✓ Branch 65 taken 4 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 4 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 4 times.
✗ Branch 70 not taken.
✓ Branch 71 taken 4 times.
✗ Branch 72 not taken.
✓ Branch 73 taken 4 times.
✗ Branch 74 not taken.
✓ Branch 75 taken 4 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 4 times.
✗ Branch 80 not taken.
✓ Branch 81 taken 4 times.
✗ Branch 83 not taken.
✓ Branch 84 taken 4 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 4 times.
✗ Branch 89 not taken.
✓ Branch 90 taken 4 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 4 times.
✗ Branch 95 not taken.
✓ Branch 96 taken 4 times.
✗ Branch 98 not taken.
✓ Branch 99 taken 4 times.
✗ Branch 100 not taken.
✓ Branch 101 taken 4 times.
✗ Branch 102 not taken.
✓ Branch 103 taken 4 times.
✗ Branch 104 not taken.
✓ Branch 105 taken 4 times.
✗ Branch 106 not taken.
✓ Branch 107 taken 4 times.
✗ Branch 108 not taken.
✓ Branch 109 taken 4 times.
✗ Branch 110 not taken.
✓ Branch 111 taken 4 times.
✗ Branch 112 not taken.
✓ Branch 113 taken 4 times.
✗ Branch 114 not taken.
✓ Branch 115 taken 4 times.
✗ Branch 117 not taken.
✓ Branch 118 taken 4 times.
✗ Branch 120 not taken.
✓ Branch 121 taken 4 times.
✗ Branch 123 not taken.
✓ Branch 124 taken 4 times.
✓ Branch 126 taken 4 times.
✓ Branch 127 taken 4 times.
✓ Branch 128 taken 4 times.
✓ Branch 129 taken 4 times.
12 ) / ( ( exp_term + 1._real32 )**2._real32 )
491
492 4 end subroutine get_partial_swish_val
493 !###############################################################################
494
495
496 !###############################################################################
497 function softmax_reverse_array(softmax, gradient, dim) result(output)
498 !! Softmax function for reverse mode autodiff
499 implicit none
500 class(array_type), intent(in), target :: softmax
501 class(array_type), intent(in), target :: gradient
502 integer, intent(in) :: dim
503 type(array_type), pointer :: output
504
505 integer :: i
506 real(real32), dimension(size(softmax%val,1), size(softmax%val,2)) :: temp_val
507
508
509 output => softmax%create_result()
510 temp_val = gradient%val * softmax%val
511 if(dim.eq.1)then
512 do concurrent(i=1:size(softmax%val,1))
513 temp_val(i, :) = temp_val(i, :) - softmax%val(i, :) * sum(temp_val(i, :))
514 end do
515 elseif(dim.eq.2)then
516 do concurrent(i=1:size(softmax%val,2))
517 temp_val(:, i) = temp_val(:, i) - softmax%val(:, i) * sum(temp_val(:, i))
518 end do
519 else
520 call stop_program("softmax_reverse_array: Unsupported dimension")
521 end if
522 output%val = temp_val
523 output%indices = [dim]
524
525 output%get_partial_left => get_partial_softmax_reverse_left
526 output%get_partial_left_val => get_partial_softmax_reverse_left_val
527 output%get_partial_right => get_partial_softmax_reverse_right
528 output%get_partial_right_val => get_partial_softmax_reverse_right_val
529 if(softmax%requires_grad .or. gradient%requires_grad)then
530 output%requires_grad = .true.
531 output%is_forward = softmax%is_forward .or. gradient%is_forward
532 output%operation = 'softmax_reverse'
533 output%left_operand => softmax
534 output%right_operand => gradient
535 output%owns_left_operand = softmax%is_temporary
536 output%owns_right_operand = gradient%is_temporary
537 end if
538
539 end function softmax_reverse_array
540 !-------------------------------------------------------------------------------
541 function get_partial_softmax_reverse_left(this, upstream_grad) result(output)
542 !! Get partial derivative of softmax reverse operation
543 implicit none
544 class(array_type), intent(inout) :: this
545 type(array_type), intent(in) :: upstream_grad
546 type(array_type) :: output
547
548 type(array_type), pointer :: sum_yg, sum_yu
549 type(array_type), pointer :: ptr
550
551 sum_yg => sum(this%left_operand * this%right_operand, dim=this%indices(1))
552 sum_yu => sum(this%left_operand * upstream_grad, dim=this%indices(1))
553
554 ptr => upstream_grad * (this%right_operand - sum_yg) - this%right_operand * sum_yu
555 call output%assign_and_deallocate_source(ptr)
556
557 end function get_partial_softmax_reverse_left
558 !-------------------------------------------------------------------------------
559 function get_partial_softmax_reverse_right(this, upstream_grad) result(output)
560 !! Get partial derivative of softmax reverse operation
561 implicit none
562 class(array_type), intent(inout) :: this
563 type(array_type), intent(in) :: upstream_grad
564 type(array_type) :: output
565
566 type(array_type), pointer :: ptr
567
568 ptr => ( &
569 upstream_grad - &
570 sum(this%left_operand * upstream_grad, dim=this%indices(1)) &
571 ) * this%left_operand
572 call output%assign_and_deallocate_source(ptr)
573
574 end function get_partial_softmax_reverse_right
575 !-------------------------------------------------------------------------------
576 pure subroutine get_partial_softmax_reverse_left_val(this, upstream_grad, output)
577 !! Get partial derivative of softmax reverse operation (in-place version)
578 implicit none
579 class(array_type), intent(in) :: this
580 real(real32), dimension(:,:), intent(in) :: upstream_grad
581 real(real32), dimension(:,:), intent(out) :: output
582
583 integer :: dim, i
584 real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yg
585 real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yu
586
587 dim = this%indices(1)
588 sum_yg = sum(this%left_operand%val * this%right_operand%val, dim=dim)
589 sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim)
590
591 if(dim.eq.1)then
592 do concurrent(i=1:size(this%val,2))
593 output(:, i) = &
594 upstream_grad(:, i) * (this%right_operand%val(:, i) - sum_yg(i)) - &
595 this%right_operand%val(:, i) * sum_yu(i)
596 end do
597 else
598 do concurrent(i=1:size(this%val,1))
599 output(i, :) = &
600 upstream_grad(i, :) * (this%right_operand%val(i, :) - sum_yg(i)) - &
601 this%right_operand%val(i, :) * sum_yu(i)
602 end do
603 end if
604
605 end subroutine get_partial_softmax_reverse_left_val
606 !-------------------------------------------------------------------------------
607 pure subroutine get_partial_softmax_reverse_right_val(this, upstream_grad, output)
608 !! Get partial derivative of softmax reverse operation (in-place version)
609 implicit none
610 class(array_type), intent(in) :: this
611 real(real32), dimension(:,:), intent(in) :: upstream_grad
612 real(real32), dimension(:,:), intent(out) :: output
613
614 integer :: dim, i
615 real(real32), dimension(size(this%val,3-this%indices(1))) :: sum_yu
616
617 dim = this%indices(1)
618 if(dim.eq.1)then
619 sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim)
620 do concurrent(i=1:size(this%val,1))
621 output(i, :) = upstream_grad(i, :) - sum_yu(i) * this%left_operand%val(i, :)
622 end do
623 else
624 sum_yu = sum(this%left_operand%val * upstream_grad, dim=dim)
625 do concurrent(i=1:size(this%val,2))
626 output(:, i) = upstream_grad(:, i) - sum_yu(i) * this%left_operand%val(:, i)
627 end do
628 end if
629 end subroutine get_partial_softmax_reverse_right_val
630 !###############################################################################
631
632 end submodule athena__diffstruc_extd_submodule
633