GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_sub_batchnorm.f90
Date: 2025-12-10 07:37:07
            Exec  Total  Coverage
Lines:         1      1    100.0%
Functions:     0      0        -%
Branches:      2      4     50.0%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_batchnorm
2 !! Submodule containing the batch normalisation implementations for extended diffstruc array operations
3
4 contains
5
6 !###############################################################################
7 module function batchnorm_inference( &
8 input, params, mean, variance, epsilon &
9 ) result( output )
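!! Batch normalisation in inference mode: normalises the input using the
!! supplied (fixed) running mean and variance rather than batch statistics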
10 implicit none
11 class(array_type), intent(in), target :: input
12 class(array_type), intent(in), target :: params
13 real(real32), dimension(:), intent(in) :: mean
14 real(real32), dimension(:), intent(in) :: variance
15 real(real32), intent(in) :: epsilon
16 type(batchnorm_array_type), pointer :: output
17
18 integer :: i, c, s
19 integer :: num_elements, num_dims
20
21 allocate(output); call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
22 output%epsilon = epsilon
23 output%mean = mean
24 output%variance = variance
25 num_dims = size(input%shape)
26 num_elements = product(input%shape(1:num_dims - 1))
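! Apply y = scale * (x - mean) / sqrt(variance + epsilon) + offset per channel,
! where the scales are stored in params%val(1:C,1) and the offsets in
! params%val(C+1:2*C,1), with C = input%shape(num_dims)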
27 do concurrent(c = 1:input%shape(num_dims))
28 do concurrent(s = 1:size(input%val,2), i = 1:num_elements)
29 output%val(i + (c-1) * num_elements, s) = &
30 params%val(c,1) * ( &
31 input%val(i + (c-1) * num_elements, s) &
32 - mean(c) &
33 ) / &
34 sqrt(variance(c) + output%epsilon) + &
35 params%val(c+input%shape(num_dims),1)
36 end do
37 end do
38
39 end function batchnorm_inference
40 !-------------------------------------------------------------------------------
41 module function batchnorm( &
42 input, params, momentum, mean, variance, epsilon &
43 ) result( output )
44 !! Batch normalisation operation
45 implicit none
46
47 ! Arguments
48 class(array_type), intent(in), target :: input
49 class(array_type), intent(in), target :: params
50 real(real32), intent(in) :: momentum
51 real(real32), dimension(:), intent(in) :: mean
52 real(real32), dimension(:), intent(in) :: variance
53 real(real32), intent(in) :: epsilon
54 type(batchnorm_array_type), pointer :: output
55
56 ! Local variables
57 integer :: i, c, s
58 integer :: num_elements, num_dims
59 real(real32) :: mu, var, norm
60
61 allocate(output)
62 if(output%allocated) call output%deallocate()
63 call output%allocate(array_shape = [ input%shape, size(input%val,2) ])
64 output%epsilon = epsilon
65 output%mean = mean
66 output%variance = variance
67 num_dims = size(input%shape)
68 num_elements = product(input%shape(1:num_dims - 1))
69 norm = real(num_elements * size(input%val,2), real32)
70 do concurrent(c = 1:input%shape(num_dims))
71 mu = 0._real32
72 var = 0._real32
73 mu = sum(input%val((c-1) * num_elements+1:c*num_elements,:)) / norm
74 var = sum( (input%val((c-1) * num_elements+1:c*num_elements,:) - mu) ** 2 ) / &
75 norm
76
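! Update the running statistics: for momentum > 0 blend the incoming
! mean/variance with the batch statistics, otherwise use the batch
! statistics directly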
77 if(momentum .gt. 1.E-8_real32) then
78 output%mean(c) = momentum * mean(c) + (1._real32 - momentum) * mu
79 output%variance(c) = momentum * variance(c) + (1._real32 - momentum) * var
80 else
81 output%mean(c) = mu
82 output%variance(c) = var
83 end if
84
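! Normalise each element with the current batch statistics (mu, var), then
! apply the per-channel scale params%val(c,1) and offset
! params%val(c+input%shape(num_dims),1)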
85 do concurrent(s = 1:size(input%val,2), i = 1:num_elements)
86 output%val(i + (c-1) * num_elements, s) = &
87 params%val(c,1) * ( input%val(i + (c-1) * num_elements, s) - mu ) / &
88 sqrt(var + output%epsilon) + params%val(c+input%shape(num_dims),1)
89 end do
90 end do
91
92 output%get_partial_left => get_partial_batchnorm_left
93 output%get_partial_left_val => get_partial_batchnorm_left_val
94 output%get_partial_right => get_partial_batchnorm_right
95 output%get_partial_right_val => get_partial_batchnorm_right_val
96 if(input%requires_grad .or. params%requires_grad) then
97 output%requires_grad = .true.
98 output%is_forward = input%is_forward .or. params%is_forward
99 output%operation = 'batchnorm'
100 output%left_operand => input
101 output%right_operand => params
102 end if
103
104 end function batchnorm
105 !-------------------------------------------------------------------------------
106 function get_partial_batchnorm_left(this, upstream_grad) result(output)
107 implicit none
108 class(array_type), intent(inout) :: this
109 type(array_type), intent(in) :: upstream_grad
110 type(array_type) :: output
111
112 class(array_type), pointer :: input
113
114 input => this%left_operand
115 call output%allocate(array_shape = [ input%shape, size(upstream_grad%val,2) ])
116
117 call this%get_partial_left_val(upstream_grad%val, output%val)
118
119 end function get_partial_batchnorm_left
120 !-------------------------------------------------------------------------------
121 function get_partial_batchnorm_right(this, upstream_grad) result(output)
122 implicit none
123 class(array_type), intent(inout) :: this
124 type(array_type), intent(in) :: upstream_grad
125 type(array_type) :: output
126
127 class(array_type), pointer :: params
128
129 params => this%right_operand
130 call output%allocate(array_shape = [ params%shape, 1 ])
131
132 call this%get_partial_right_val(upstream_grad%val, output%val)
133
134 end function get_partial_batchnorm_right
135 !-------------------------------------------------------------------------------
136 pure subroutine get_partial_batchnorm_left_val(this, upstream_grad, output)
137 !! Get partial derivative wrt input for batchnorm (subroutine version)
138 implicit none
139
140 class(array_type), intent(in) :: this
141 real(real32), dimension(:,:), intent(in) :: upstream_grad
142 real(real32), dimension(:,:), intent(out) :: output
143
144 integer :: i, c, s, num_dims, num_elements
145 real(real32), allocatable :: x_hat(:,:), dx_hat(:,:)
146 real(real32) :: mu, var, eps, norm
147 integer, dimension(size(this%shape)) :: input_shape
148
149 input_shape = this%left_operand%shape
150
151 select type(this)
152 type is (batchnorm_array_type)
153 eps = this%epsilon
154 num_dims = size(this%shape)
155 num_elements = product(this%shape(1:num_dims - 1))
156
157 output = 0._real32
158
159 allocate(x_hat(num_elements, size(upstream_grad,2)))
160 allocate(dx_hat(num_elements, size(upstream_grad,2)))
161 norm = real( &
162 product(input_shape(1:num_dims - 1)) * size(upstream_grad,2), &
163 real32 &
164 )
165
166 do c = 1, input_shape(num_dims)
167 mu = this%mean(c)
168 var = this%variance(c)
169
170 ! Normalised input
171 x_hat = ( &
172 this%left_operand%val((c-1)*num_elements+1:c*num_elements,:) - &
173 mu &
174 ) / sqrt(var + eps)
175
176 ! Gradient of normalised input
177 dx_hat = upstream_grad((c-1)*num_elements+1:c*num_elements,:) * &
178 this%right_operand%val(c,1)
179
180 ! Gradient wrt input
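!   dL/dx = ( norm*dx_hat - sum(dx_hat) - x_hat*sum(dx_hat*x_hat) ) / (norm*sqrt(var+eps))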
181 do concurrent(s = 1:size(upstream_grad,2), i = 1:num_elements)
182 output(i + (c-1)*num_elements,s) = &
183 (1._real32 / (norm * sqrt(var + eps))) * &
184 (norm * dx_hat(i,s) - sum(dx_hat) - x_hat(i,s) * sum(dx_hat * x_hat))
185 end do
186 end do
187 end select
188
189 end subroutine get_partial_batchnorm_left_val
190 !-------------------------------------------------------------------------------
191 pure subroutine get_partial_batchnorm_right_val(this, upstream_grad, output)
192 !! Get partial derivative wrt params for batchnorm (subroutine version)
193 implicit none
194
195 class(array_type), intent(in) :: this
196 real(real32), dimension(:,:), intent(in) :: upstream_grad
197 real(real32), dimension(:,:), intent(out) :: output
198
199 integer :: c, num_dims, num_elements
200 real(real32), allocatable :: x_hat(:,:)
201 real(real32) :: mu, var, eps
202 integer, dimension(size(this%shape)) :: input_shape
203
204 input_shape = this%left_operand%shape
205
206 select type(this)
207 type is (batchnorm_array_type)
208 eps = this%epsilon
209 num_dims = size(this%shape)
210 num_elements = product(this%shape(1:num_dims - 1))
211
212 output = 0._real32
213
214 allocate(x_hat(num_elements, size(upstream_grad,2)))
215
216 do c = 1, input_shape(num_dims)
217 mu = this%mean(c)
218 var = this%variance(c)
219
220 ! Normalised input
221 x_hat(:,:) = ( &
222 this%left_operand%val((c-1)*num_elements+1:c*num_elements,:) - mu &
223 ) / sqrt(var + eps)
224
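! Per-channel parameter gradients: row c accumulates dL/d(scale) = sum(dy * x_hat)
! and row c + input_shape(num_dims) accumulates dL/d(offset) = sum(dy)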
225 output(c,1) = &
226 sum(upstream_grad((c-1)*num_elements+1:c*num_elements,:) * x_hat)
227 output(c + input_shape(num_dims),1) = &
228 sum(upstream_grad((c-1)*num_elements+1:c*num_elements,:))
229
230 end do
231 end select
232
233 end subroutine get_partial_batchnorm_right_val
234 !###############################################################################
235
236
237 2/4 6 end submodule athena__diffstruc_extd_submodule_batchnorm
        ✓ Branch 0 (27→28) taken 3 times.
        ✗ Branch 1 (27→335) not taken.
        ✓ Branch 2 (27→28) taken 3 times.
        ✗ Branch 3 (27→226) not taken.