GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_diffstruc_extd_sub_batchnorm.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 74 107 69.2%
Functions: 0 0 -%
Branches: 374 842 44.4%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_batchnorm
2 !! Submodule containing implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
7 module function batchnorm_inference( &
8 input, params, mean, variance, epsilon &
9 ) result( output )
10 implicit none
11 class(array_type), intent(in), target :: input
12 class(array_type), intent(in), target :: params
13 real(real32), dimension(:), intent(in) :: mean
14 real(real32), dimension(:), intent(in) :: variance
15 real(real32), intent(in) :: epsilon
16 type(batchnorm_array_type), pointer :: output
17
18 integer :: i, c, s
19 integer :: num_elements, num_dims
20
21 allocate(output)
22 if(output%allocated) call output%deallocate()
23 call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
24 output%epsilon = epsilon
25 output%mean = mean
26 output%variance = variance
27 num_dims = size(input%shape)
28 num_elements = product(input%shape(1:num_dims - 1))
29 do concurrent(c = 1:input%shape(num_dims))
30 do concurrent(s = 1:size(input%val,2), i = 1:num_elements)
31 output%val(i + (c-1) * num_elements, s) = &
32 params%val(c,1) * ( input%val(i + (c-1) * num_elements, s) - &
33 mean(c) ) / sqrt(variance(c) + output%epsilon) + &
34 params%val(c+input%shape(num_dims),1)
35 end do
36 end do
37
38 end function batchnorm_inference
39 !-------------------------------------------------------------------------------
40 9 module function batchnorm( &
41
2/4
✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 9 times.
✗ Branch 3 not taken.
9 input, params, momentum, mean, variance, epsilon &
42 ) result( output )
43 !! Batch normalisation operation
44 implicit none
45
46 ! Arguments
47 class(array_type), intent(in), target :: input
48 class(array_type), intent(in), target :: params
49 real(real32), intent(in) :: momentum
50 real(real32), dimension(:), intent(in) :: mean
51 real(real32), dimension(:), intent(in) :: variance
52 real(real32), intent(in) :: epsilon
53 type(batchnorm_array_type), pointer :: output
54
55 ! Local variables
56 integer :: i, c, s
57 integer :: num_elements, num_dims
58 real(real32) :: mu, var, norm
59
60
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
9 allocate(output)
61
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
9 if(output%allocated) call output%deallocate()
62
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 9 times.
✓ Branch 15 taken 26 times.
✓ Branch 16 taken 9 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 9 times.
✓ Branch 19 taken 35 times.
✓ Branch 20 taken 9 times.
70 call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
63 9 output%epsilon = epsilon
64
7/14
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 9 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 32 times.
✓ Branch 16 taken 9 times.
41 output%mean = mean
65
7/14
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 9 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 32 times.
✓ Branch 16 taken 9 times.
41 output%variance = variance
66 9 num_dims = size(input%shape)
67
6/10
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 9 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 9 times.
✓ Branch 12 taken 17 times.
✓ Branch 13 taken 9 times.
26 num_elements = product(input%shape(1:num_dims - 1))
68 9 norm = real(num_elements * size(input%val,2), real32)
69
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 9 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 9 times.
9 do concurrent(c = 1:input%shape(num_dims))
70 32 mu = 0._real32
71 32 var = 0._real32
72
12/20
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 32 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 32 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 32 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 32 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 32 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 32 times.
✓ Branch 24 taken 104 times.
✓ Branch 25 taken 32 times.
✓ Branch 26 taken 3558 times.
✓ Branch 27 taken 104 times.
3694 mu = sum(input%val((c-1) * num_elements+1:c*num_elements,:)) / norm
73 256 var = sum( (input%val((c-1) * num_elements+1:c*num_elements,:) - mu) ** 2 ) / &
74
12/20
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 32 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 32 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 32 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 32 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 32 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 32 times.
✓ Branch 24 taken 104 times.
✓ Branch 25 taken 32 times.
✓ Branch 26 taken 3558 times.
✓ Branch 27 taken 104 times.
3694 norm
75
76
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
32 if(momentum .gt. 1.E-8_real32)then
77 output%mean(c) = momentum * mean(c) + (1._real32 - momentum) * mu
78 output%variance(c) = momentum * variance(c) + (1._real32 - momentum) * var
79 else
80
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
32 output%mean(c) = mu
81
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 32 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 32 times.
32 output%variance(c) = var
82 end if
83
84
2/2
✓ Branch 0 taken 32 times.
✓ Branch 1 taken 9 times.
41 do concurrent(s = 1:size(input%val,2), i = 1:num_elements)
85 output%val(i + (c-1) * num_elements, s) = &
86 28464 params%val(c,1) * ( input%val(i + (c-1) * num_elements, s) - mu ) / &
87
22/40
✓ Branch 0 taken 3486 times.
✓ Branch 1 taken 32 times.
✓ Branch 2 taken 3558 times.
✓ Branch 3 taken 3486 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3558 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3558 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3558 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3558 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3558 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3558 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 3558 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 3558 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 3558 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 3558 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 3558 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 3558 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 3558 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 3558 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 3558 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 3558 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 3558 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 3558 times.
7076 sqrt(var + output%epsilon) + params%val(c+input%shape(num_dims),1)
88 end do
89 end do
90
91 9 output%get_partial_left => get_partial_batchnorm_left
92 9 output%get_partial_left_val => get_partial_batchnorm_left_val
93 9 output%get_partial_right => get_partial_batchnorm_right
94 9 output%get_partial_right_val => get_partial_batchnorm_right_val
95
1/2
✓ Branch 0 taken 9 times.
✗ Branch 1 not taken.
9 if(input%requires_grad .or. params%requires_grad)then
96 9 output%requires_grad = .true.
97 9 output%is_forward = input%is_forward .or. params%is_forward
98 9 output%operation = 'batchnorm'
99 9 output%left_operand => input
100 9 output%right_operand => params
101 end if
102
103 9 end function batchnorm
104 !-------------------------------------------------------------------------------
105 function get_partial_batchnorm_left(this, upstream_grad) result(output)
106 implicit none
107 class(array_type), intent(inout) :: this
108 type(array_type), intent(in) :: upstream_grad
109 type(array_type) :: output
110
111 class(array_type), pointer :: input
112
113 input => this%left_operand
114 call output%allocate(array_shape = [ input%shape, size(upstream_grad%val,2) ])
115
116 call this%get_partial_left_val(upstream_grad%val, output%val)
117
118 end function get_partial_batchnorm_left
119 !-------------------------------------------------------------------------------
120 function get_partial_batchnorm_right(this, upstream_grad) result(output)
121 implicit none
122 class(array_type), intent(inout) :: this
123 type(array_type), intent(in) :: upstream_grad
124 type(array_type) :: output
125
126 class(array_type), pointer :: params
127
128 params => this%right_operand
129 call output%allocate(array_shape = [ params%shape, 1 ])
130
131 call this%get_partial_right_val(upstream_grad%val, output%val)
132
133 end function get_partial_batchnorm_right
134 !-------------------------------------------------------------------------------
135
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 pure subroutine get_partial_batchnorm_left_val(this, upstream_grad, output)
136 !! Get partial derivative wrt input for batchnorm (subroutine version)
137 implicit none
138
139 class(array_type), intent(in) :: this
140 real(real32), dimension(:,:), intent(in) :: upstream_grad
141 real(real32), dimension(:,:), intent(out) :: output
142
143 integer :: i, c, s, num_dims, num_elements
144 3 real(real32), allocatable :: x_hat(:,:), dx_hat(:,:)
145 real(real32) :: mu, var, eps, norm
146 6 integer, dimension(size(this%shape)) :: input_shape
147
148
10/18
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 24 taken 9 times.
✓ Branch 25 taken 3 times.
12 input_shape = this%left_operand%shape
149
150 select type(this)
151 type is (batchnorm_array_type)
152 3 eps = this%epsilon
153 3 num_dims = size(this%shape)
154
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✓ Branch 10 taken 6 times.
✓ Branch 11 taken 3 times.
9 num_elements = product(this%shape(1:num_dims - 1))
155
156
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 18 taken 7 times.
✓ Branch 19 taken 3 times.
✓ Branch 20 taken 1768 times.
✓ Branch 21 taken 7 times.
1778 output = 0._real32
157
158
15/30
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 22 taken 3 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 3 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 3 times.
3 allocate(x_hat(num_elements, size(upstream_grad,2)))
159
15/30
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 22 taken 3 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 3 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 3 times.
3 allocate(dx_hat(num_elements, size(upstream_grad,2)))
160 norm = real( &
161 product(input_shape(1:num_dims - 1)) * size(upstream_grad,2), &
162 real32 &
163
11/20
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✓ Branch 6 taken 6 times.
✓ Branch 7 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
9 )
164
165
5/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✓ Branch 8 taken 14 times.
✓ Branch 9 taken 3 times.
20 do c = 1, input_shape(num_dims)
166
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 mu = this%mean(c)
167
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 var = this%variance(c)
168
169 ! Normalised input
170 112 x_hat = ( &
171 this%left_operand%val((c-1)*num_elements+1:c*num_elements,:) - &
172 mu &
173
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 14 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 14 times.
✓ Branch 24 taken 14 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✓ Branch 32 taken 46 times.
✓ Branch 33 taken 14 times.
✓ Branch 34 taken 1768 times.
✓ Branch 35 taken 46 times.
1828 ) / sqrt(var + eps)
174
175 ! Gradient of normalised input
176 dx_hat = upstream_grad((c-1)*num_elements+1:c*num_elements,:) * &
177
18/34
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 14 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 14 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 14 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 14 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 14 times.
✓ Branch 26 taken 14 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 14 times.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✓ Branch 34 taken 46 times.
✓ Branch 35 taken 14 times.
✓ Branch 36 taken 1768 times.
✓ Branch 37 taken 46 times.
1828 this%right_operand%val(c,1)
178
179 ! Gradient wrt input
180
6/12
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
17 do concurrent(s = 1:size(upstream_grad,2), i = 1:num_elements)
181
2/4
✗ Branch 1 not taken.
✓ Branch 2 taken 1768 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1768 times.
3536 output(i + (c-1)*num_elements,s) = &
182 (1._real32 / (norm * sqrt(var + eps))) * &
183
48/84
✓ Branch 0 taken 1736 times.
✓ Branch 1 taken 14 times.
✓ Branch 2 taken 1768 times.
✓ Branch 3 taken 1736 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1768 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1768 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1768 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1768 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1768 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1768 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1768 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1768 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1768 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1768 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1768 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 1768 times.
✓ Branch 40 taken 1928 times.
✓ Branch 41 taken 1768 times.
✓ Branch 42 taken 798920 times.
✓ Branch 43 taken 1928 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 1768 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 1768 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1768 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 1768 times.
✗ Branch 56 not taken.
✓ Branch 57 taken 1768 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 1768 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 1768 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 1768 times.
✗ Branch 68 not taken.
✓ Branch 69 taken 1768 times.
✗ Branch 71 not taken.
✓ Branch 72 taken 1768 times.
✗ Branch 74 not taken.
✓ Branch 75 taken 1768 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 1768 times.
✗ Branch 80 not taken.
✓ Branch 81 taken 1768 times.
✗ Branch 83 not taken.
✓ Branch 84 taken 1768 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 1768 times.
✗ Branch 89 not taken.
✓ Branch 90 taken 1768 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 1768 times.
✗ Branch 95 not taken.
✓ Branch 96 taken 1768 times.
✗ Branch 98 not taken.
✓ Branch 99 taken 1768 times.
✗ Branch 101 not taken.
✓ Branch 102 taken 1768 times.
✗ Branch 104 not taken.
✓ Branch 105 taken 1768 times.
✗ Branch 107 not taken.
✓ Branch 108 taken 1768 times.
✓ Branch 110 taken 1928 times.
✓ Branch 111 taken 1768 times.
✓ Branch 112 taken 798920 times.
✓ Branch 113 taken 1928 times.
✗ Branch 114 not taken.
✓ Branch 115 taken 1768 times.
✗ Branch 116 not taken.
✓ Branch 117 taken 1768 times.
1606982 (norm * dx_hat(i,s) - sum(dx_hat) - x_hat(i,s) * sum(dx_hat * x_hat))
184 end do
185 end do
186 end select
187
188
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 end subroutine get_partial_batchnorm_left_val
189 !-------------------------------------------------------------------------------
190
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
3 pure subroutine get_partial_batchnorm_right_val(this, upstream_grad, output)
191 !! Get partial derivative wrt params for batchnorm (subroutine version)
192 implicit none
193
194 class(array_type), intent(in) :: this
195 real(real32), dimension(:,:), intent(in) :: upstream_grad
196 real(real32), dimension(:,:), intent(out) :: output
197
198 integer :: c, num_dims, num_elements
199 3 real(real32), allocatable :: x_hat(:,:)
200 real(real32) :: mu, var, eps
201 6 integer, dimension(size(this%shape)) :: input_shape
202
203
10/18
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 3 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 3 times.
✓ Branch 24 taken 9 times.
✓ Branch 25 taken 3 times.
12 input_shape = this%left_operand%shape
204
205 select type(this)
206 type is (batchnorm_array_type)
207 3 eps = this%epsilon
208 3 num_dims = size(this%shape)
209
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✓ Branch 10 taken 6 times.
✓ Branch 11 taken 3 times.
9 num_elements = product(this%shape(1:num_dims - 1))
210
211
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 3 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 3 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 3 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 3 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 3 times.
✓ Branch 18 taken 7 times.
✓ Branch 19 taken 3 times.
✓ Branch 20 taken 92 times.
✓ Branch 21 taken 7 times.
102 output = 0._real32
212
213
15/30
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 3 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 3 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 3 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 3 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 3 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 3 times.
✓ Branch 22 taken 3 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 25 taken 3 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 3 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 3 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 3 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 3 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 3 times.
3 allocate(x_hat(num_elements, size(upstream_grad,2)))
214
215
5/8
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✓ Branch 8 taken 14 times.
✓ Branch 9 taken 3 times.
20 do c = 1, input_shape(num_dims)
216
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 mu = this%mean(c)
217
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
14 var = this%variance(c)
218
219 ! Normalised input
220 x_hat(:,:) = ( &
221 this%left_operand%val((c-1)*num_elements+1:c*num_elements,:) - mu &
222
22/40
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 14 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 14 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 14 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 14 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 14 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 14 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 14 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 14 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 14 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 14 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 14 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 14 times.
✓ Branch 46 taken 46 times.
✓ Branch 47 taken 14 times.
✓ Branch 48 taken 1768 times.
✓ Branch 49 taken 46 times.
1828 ) / sqrt(var + eps)
223
224
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 14 times.
28 output(c,1) = &
225
22/40
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 14 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 14 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 14 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 14 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 14 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 14 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 14 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 14 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 14 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 14 times.
✓ Branch 51 taken 46 times.
✓ Branch 52 taken 14 times.
✓ Branch 53 taken 1768 times.
✓ Branch 54 taken 46 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 14 times.
1842 sum(upstream_grad((c-1)*num_elements+1:c*num_elements,:) * x_hat)
226
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 14 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 14 times.
42 output(c + input_shape(num_dims),1) = &
227
13/22
✗ Branch 0 not taken.
✓ Branch 1 taken 14 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 14 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 14 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 14 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 14 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 14 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 14 times.
✓ Branch 21 taken 46 times.
✓ Branch 22 taken 14 times.
✓ Branch 23 taken 1768 times.
✓ Branch 24 taken 46 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 14 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 14 times.
1873 sum(upstream_grad((c-1)*num_elements+1:c*num_elements,:))
228
229 end do
230 end select
231
232
1/2
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
3 end subroutine get_partial_batchnorm_right_val
233 !###############################################################################
234
235
2/4
✓ Branch 0 taken 3 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
6 end submodule athena__diffstruc_extd_submodule_batchnorm
236