GCC Code Coverage Report


Directory: src/athena/
File: src/athena/athena_orthogonal_attention_layer.f90
Date: 2026-04-15 16:08:59
Exec Total Coverage
Lines: 306 346 88.4%
Functions: 0 0 -%
Branches: 920 1925 47.8%

Line Branch Exec Source
1 module athena__orthogonal_attention_layer
2 !! Module containing implementation of an Orthogonal Attention layer
3 !!
4 !! This module implements the Orthogonal Attention mechanism from
5 !! "Improved Operator Learning by Orthogonal Attention" (Xiao et al., 2024).
6 !!
7 !! Instead of softmax attention, this layer projects queries and keys
8 !! onto a learned orthonormal basis of dimension \(k \ll n_{in}\), giving
9 !! a linear-cost approximation to the attention kernel.
10 !!
11 !! Given input \(\mathbf{u} \in \mathbb{R}^{n_{in}}\):
12 !!
13 !! \[
14 !! \mathbf{Q} = \mathbf{W}_Q\,\mathbf{u}, \quad
15 !! \mathbf{K} = \mathbf{W}_K\,\mathbf{u}, \quad
16 !! \mathbf{V} = \mathbf{W}_V\,\mathbf{u}
17 !! \]
18 !!
19 !! The orthogonal basis \(\mathbf{\Phi} \in \mathbb{R}^{n_{in} \times k}\)
20 !! is obtained by QR decomposition of learnable weights
21 !! \(\mathbf{B} \in \mathbb{R}^{n_{in} \times k}\).
22 !!
23 !! The attention output is:
24 !! \[
25 !! \text{Attn}(\mathbf{u}) = \mathbf{\Phi}\,
26 !! (\mathbf{\Phi}^T \mathbf{Q})^T\,
27 !! (\mathbf{\Phi}^T \mathbf{K})\,
28 !! \mathbf{V}
29 !! \]
30 !!
31 !! The layer output is:
32 !! \[
33 !! \mathbf{v} = \sigma\!\bigl(
34 !! \text{Attn}(\mathbf{u}) + \mathbf{W}\,\mathbf{u} + \mathbf{b}
35 !! \bigr)
36 !! \]
37 !!
38 !! Parameters (learnable):
39 !! - \(\mathbf{W}_Q \in \mathbb{R}^{d_k \times n_{in}}\)
40 !! - \(\mathbf{W}_K \in \mathbb{R}^{d_k \times n_{in}}\)
41 !! - \(\mathbf{W}_V \in \mathbb{R}^{n_{out} \times n_{in}}\)
42 !! - \(\mathbf{B} \in \mathbb{R}^{n_{in} \times k}\) (basis, orthogonalised)
43 !! - \(\mathbf{W} \in \mathbb{R}^{n_{out} \times n_{in}}\) (bypass)
44 !! - \(\mathbf{b} \in \mathbb{R}^{n_{out}}\) (optional bias)
45 use coreutils, only: real32, stop_program
46 use athena__base_layer, only: learnable_layer_type, base_layer_type
47 use athena__misc_types, only: base_actv_type, base_init_type, &
48 onnx_attribute_type
49 use diffstruc, only: array_type, matmul, operator(+), operator(*), tanh
50 use athena__diffstruc_extd, only: ono_encode, ono_decode, softmax
51 implicit none
52
53
54 private
55
56 public :: orthogonal_attention_layer_type
57 public :: read_orthogonal_attention_layer
58
59
60 type, extends(learnable_layer_type) :: orthogonal_attention_layer_type
61 !! Type for an Orthogonal Attention layer
62 integer :: num_inputs = 0
63 !! Number of input features / discretisation points
64 integer :: num_outputs = 0
65 !! Number of output features / discretisation points
66 integer :: num_basis = 0
67 !! Number of orthogonal basis functions (k)
68 integer :: key_dim = 0
69 !! Dimension of query/key projections (d_k)
70 type(array_type), dimension(1) :: z
71 !! Temporary array for pre-activation values
72 contains
73 procedure, pass(this) :: get_num_params => get_num_params_ono_attn
74 procedure, pass(this) :: set_hyperparams => set_hyperparams_ono_attn
75 procedure, pass(this) :: init => init_ono_attn
76 procedure, pass(this) :: print_to_unit => print_to_unit_ono_attn
77 procedure, pass(this) :: read => read_ono_attn
78
79 procedure, pass(this) :: forward => forward_ono_attn
80 procedure, pass(this) :: get_bases => get_bases_ono_attn
81 procedure, pass(this) :: get_attributes => get_attributes_ono_attn
82
83 final :: finalise_ono_attn
84 end type orthogonal_attention_layer_type
85
86 interface orthogonal_attention_layer_type
87 module function layer_setup( &
88 num_outputs, num_basis, key_dim, &
89 num_inputs, use_bias, &
90 activation, &
91 kernel_initialiser, bias_initialiser, verbose &
92 ) result(layer)
93 integer, intent(in) :: num_outputs
94 integer, intent(in) :: num_basis
95 integer, optional, intent(in) :: key_dim
96 integer, optional, intent(in) :: num_inputs
97 logical, optional, intent(in) :: use_bias
98 class(*), optional, intent(in) :: activation
99 class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
100 integer, optional, intent(in) :: verbose
101 type(orthogonal_attention_layer_type) :: layer
102 end function layer_setup
103 end interface orthogonal_attention_layer_type
104
105
106
107 contains
108
109 !###############################################################################
110 4 subroutine finalise_ono_attn(this)
111 !! Finalise the orthogonal attention layer
112 implicit none
113
114 ! Arguments
115 type(orthogonal_attention_layer_type), intent(inout) :: this
116 !! Layer instance to release
117
118
2/4
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
4 if(allocated(this%input_shape)) deallocate(this%input_shape)
119
3/6
✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✓ Branch 5 taken 4 times.
✗ Branch 6 not taken.
4 if(allocated(this%output)) deallocate(this%output)
120
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 3 times.
4 if(this%z(1)%allocated) call this%z(1)%deallocate()
121
122 4 end subroutine finalise_ono_attn
123 !###############################################################################
124
125
126 !###############################################################################
127 7 pure function get_num_params_ono_attn(this) result(num_params)
128 !! Return the number of learnable parameters for the layer
129 implicit none
130
131 ! Arguments
132 class(orthogonal_attention_layer_type), intent(in) :: this
133 !! Layer instance
134 integer :: num_params
135 !! Total number of learnable parameters
136
137 ! W_Q: key_dim * num_inputs
138 ! W_K: key_dim * num_inputs
139 ! W_V: num_outputs * num_inputs
140 ! B: num_inputs * num_basis (basis weights to orthogonalise)
141 ! W: num_outputs * num_inputs (bypass)
142 ! b: num_outputs (optional)
143 num_params = this%key_dim * this%num_inputs + & ! W_Q
144 this%key_dim * this%num_inputs + & ! W_K
145 this%num_outputs * this%num_inputs + & ! W_V
146 this%num_inputs * this%num_basis + & ! B
147 7 this%num_outputs * this%num_inputs ! W
148
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 1 times.
7 if(this%use_bias) num_params = num_params + this%num_outputs
149
150 7 end function get_num_params_ono_attn
151 !###############################################################################
152
153
154 !###############################################################################
155 7 module function layer_setup( &
156 num_outputs, num_basis, key_dim, &
157 num_inputs, use_bias, &
158 activation, &
159 kernel_initialiser, bias_initialiser, verbose &
160
9/16
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 7 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 7 times.
14 ) result(layer)
161 use athena__activation, only: activation_setup
162 use athena__initialiser, only: initialiser_setup
163 implicit none
164
165 ! Arguments
166 integer, intent(in) :: num_outputs
167 !! Number of output features
168 integer, intent(in) :: num_basis
169 !! Number of orthogonal basis vectors
170 integer, optional, intent(in) :: key_dim
171 !! Query/key projection dimension
172 integer, optional, intent(in) :: num_inputs
173 !! Number of input features when known at construction time
174 logical, optional, intent(in) :: use_bias
175 !! Whether to allocate a bias term
176 class(*), optional, intent(in) :: activation
177 !! Activation function specification
178 class(*), optional, intent(in) :: kernel_initialiser, bias_initialiser
179 !! Kernel and bias initialiser specifications
180 integer, optional, intent(in) :: verbose
181 !! Verbosity level
182
183 type(orthogonal_attention_layer_type) :: layer
184 !! Constructed orthogonal attention layer
185
186 ! Local variables
187 integer :: verbose_ = 0
188 !! Effective verbosity level
189 integer :: key_dim_
190 !! Query/key projection dimension after defaults
191 logical :: use_bias_ = .true.
192 !! Effective bias flag
193 21 class(base_actv_type), allocatable :: activation_
194 !! Materialised activation object
195 7 class(base_init_type), allocatable :: kernel_initialiser_, bias_initialiser_
196 !! Materialised kernel and bias initialisers
197
198
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
7 if(present(verbose)) verbose_ = verbose
199
2/2
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 4 times.
7 if(present(use_bias)) use_bias_ = use_bias
200 7 key_dim_ = num_basis
201
2/2
✓ Branch 0 taken 2 times.
✓ Branch 1 taken 5 times.
7 if(present(key_dim)) key_dim_ = key_dim
202
203
3/4
✓ Branch 0 taken 3 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 3 times.
✗ Branch 3 not taken.
7 if(present(activation))then
204
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 3 times.
✓ Branch 17 taken 3 times.
✗ Branch 18 not taken.
3 activation_ = activation_setup(activation)
205 else
206
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 4 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 4 times.
✓ Branch 17 taken 4 times.
✗ Branch 18 not taken.
4 activation_ = activation_setup("none")
207 end if
208
209
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
7 if(present(kernel_initialiser))then
210 kernel_initialiser_ = initialiser_setup(kernel_initialiser)
211 end if
212
1/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
7 if(present(bias_initialiser))then
213 bias_initialiser_ = initialiser_setup(bias_initialiser)
214 end if
215
216 call layer%set_hyperparams( &
217 num_outputs = num_outputs, &
218 num_basis = num_basis, &
219 key_dim = key_dim_, &
220 use_bias = use_bias_, &
221 activation = activation_, &
222 kernel_initialiser = kernel_initialiser_, &
223 bias_initialiser = bias_initialiser_, &
224 verbose = verbose_ &
225 7 )
226
227
4/4
✓ Branch 0 taken 5 times.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 5 times.
✓ Branch 3 taken 5 times.
12 if(present(num_inputs)) call layer%init(input_shape=[num_inputs])
228
229
13/28
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 7 times.
✓ Branch 2 taken 7 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 7 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 7 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✓ Branch 26 taken 7 times.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✓ Branch 29 taken 7 times.
21 end function layer_setup
230 !###############################################################################
231
232
233 !###############################################################################
234 8 subroutine set_hyperparams_ono_attn( &
235 this, num_outputs, num_basis, key_dim, &
236 use_bias, &
237 activation, &
238 kernel_initialiser, bias_initialiser, &
239 verbose &
240 )
241 use athena__activation, only: activation_setup
242 use athena__initialiser, only: get_default_initialiser, initialiser_setup
243 implicit none
244
245 ! Arguments
246 class(orthogonal_attention_layer_type), intent(inout) :: this
247 !! Layer instance to configure
248 integer, intent(in) :: num_outputs
249 !! Number of output features
250 integer, intent(in) :: num_basis
251 !! Number of orthogonal basis vectors
252 integer, intent(in) :: key_dim
253 !! Query/key projection dimension
254 logical, intent(in) :: use_bias
255 !! Whether to use a bias term
256 class(base_actv_type), allocatable, intent(in) :: activation
257 !! Activation function object
258 class(base_init_type), allocatable, intent(in) :: &
259 kernel_initialiser, bias_initialiser
260 !! Kernel and bias initialiser objects
261 integer, optional, intent(in) :: verbose
262 !! Verbosity level
263
264 ! Local variables
265 character(len=256) :: buffer
266 !! Buffer for default initialiser lookup
267
268
5/8
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 8 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
8 this%name = "orthogonal_attention"
269 8 this%type = "nop"
270 8 this%input_rank = 1
271 8 this%output_rank = 1
272 8 this%use_bias = use_bias
273 8 this%num_outputs = num_outputs
274 8 this%num_basis = num_basis
275 8 this%key_dim = key_dim
276
277
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
8 if(allocated(this%activation)) deallocate(this%activation)
278
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if(.not.allocated(activation))then
279 this%activation = activation_setup("none")
280 else
281
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
8 allocate(this%activation, source=activation)
282 end if
283
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
8 if(allocated(this%kernel_init)) deallocate(this%kernel_init)
284
2/2
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 times.
8 if(.not.allocated(kernel_initialiser))then
285
1/2
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
7 buffer = get_default_initialiser(this%activation%name)
286
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 7 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 7 times.
✓ Branch 17 taken 7 times.
✗ Branch 18 not taken.
7 this%kernel_init = initialiser_setup(buffer)
287 else
288
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
1 allocate(this%kernel_init, source=kernel_initialiser)
289 end if
290
4/6
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
8 if(allocated(this%bias_init)) deallocate(this%bias_init)
291
2/2
✓ Branch 0 taken 7 times.
✓ Branch 1 taken 1 times.
8 if(.not.allocated(bias_initialiser))then
292 buffer = get_default_initialiser( &
293 this%activation%name, &
294 is_bias=.true. &
295
1/2
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
7 )
296
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 7 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 7 times.
✓ Branch 17 taken 7 times.
✗ Branch 18 not taken.
7 this%bias_init = initialiser_setup(buffer)
297 else
298
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
1 allocate(this%bias_init, source=bias_initialiser)
299 end if
300
301
1/2
✓ Branch 0 taken 8 times.
✗ Branch 1 not taken.
8 if(present(verbose))then
302
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
8 if(abs(verbose).gt.0)then
303 write(*,'("ORTHOGONAL_ATTENTION activation: ",A)') &
304 trim(this%activation%name)
305 end if
306 end if
307
308 8 end subroutine set_hyperparams_ono_attn
309 !###############################################################################
310
311
312 !###############################################################################
313
1/2
✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
7 subroutine init_ono_attn(this, input_shape, verbose)
314 !! Initialise parameter storage and output buffers for the layer
315 implicit none
316
317 ! Arguments
318 class(orthogonal_attention_layer_type), intent(inout) :: this
319 !! Layer instance to initialise
320 integer, dimension(:), intent(in) :: input_shape
321 !! Input shape used to infer num_inputs
322 integer, optional, intent(in) :: verbose
323 !! Verbosity level
324
325 ! Local variables
326 integer :: num_inputs, idx, nparams
327 !! Fan-in passed to the initialisers (num_inputs, plus one when a bias is used) and reserved scratch integers
328 integer :: verbose_ = 0
329 !! Effective verbosity level
330
331
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
7 if(present(verbose)) verbose_ = verbose
332
333
334 !---------------------------------------------------------------------------
335 ! Set shapes
336 !---------------------------------------------------------------------------
337
4/8
✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
7 if(.not.allocated(this%input_shape)) call this%set_shape(input_shape)
338
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%num_inputs = this%input_shape(1)
339
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 7 times.
✓ Branch 7 taken 7 times.
14 this%output_shape = [this%num_outputs]
340 7 this%num_params = this%get_num_params()
341
342
343 !---------------------------------------------------------------------------
344 ! Allocate learnable parameters
345 !
346 ! params(1): W_Q query projection [key_dim x num_inputs]
347 ! params(2): W_K key projection [key_dim x num_inputs]
348 ! params(3): W_V value projection [num_outputs x num_inputs]
349 ! params(4): B basis weights [num_inputs x num_basis]
350 ! params(5): W bypass weights [num_outputs x num_inputs]
351 ! params(6): b bias [num_outputs] (optional)
352 !---------------------------------------------------------------------------
353
3/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
7 allocate(this%weight_shape(2,5))
354
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 14 times.
✓ Branch 22 taken 7 times.
21 this%weight_shape(:,1) = [ this%key_dim, this%num_inputs ]
355
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 14 times.
✓ Branch 22 taken 7 times.
21 this%weight_shape(:,2) = [ this%key_dim, this%num_inputs ]
356
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 14 times.
✓ Branch 22 taken 7 times.
21 this%weight_shape(:,3) = [ this%num_outputs, this%num_inputs ]
357
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 14 times.
✓ Branch 22 taken 7 times.
21 this%weight_shape(:,4) = [ this%num_inputs, this%num_basis ]
358
9/16
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 14 times.
✓ Branch 22 taken 7 times.
21 this%weight_shape(:,5) = [ this%num_outputs, this%num_inputs ]
359
360
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 1 times.
7 if(this%use_bias)then
361
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 6 times.
✓ Branch 7 taken 6 times.
12 this%bias_shape = [ this%num_outputs ]
362
16/30
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✓ Branch 21 taken 36 times.
✓ Branch 22 taken 6 times.
✓ Branch 23 taken 36 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 36 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 36 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 36 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 36 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 36 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 36 times.
42 allocate(this%params(6))
363 else
364
16/30
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 5 times.
✓ Branch 22 taken 1 times.
✓ Branch 23 taken 5 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 35 not taken.
✓ Branch 36 taken 5 times.
6 allocate(this%params(5))
365 end if
366
367 7 num_inputs = this%num_inputs
368
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 1 times.
7 if(this%use_bias) num_inputs = this%num_inputs + 1
369
370 ! W_Q
371
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 21 times.
✓ Branch 7 taken 7 times.
28 call this%params(1)%allocate([this%key_dim, this%num_inputs, 1])
372
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 call this%params(1)%set_requires_grad(.true.)
373
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(1)%fix_pointer = .true.
374
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(1)%is_sample_dependent = .false.
375
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(1)%is_temporary = .false.
376
377 ! W_K
378
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 21 times.
✓ Branch 7 taken 7 times.
28 call this%params(2)%allocate([this%key_dim, this%num_inputs, 1])
379
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 call this%params(2)%set_requires_grad(.true.)
380
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(2)%fix_pointer = .true.
381
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(2)%is_sample_dependent = .false.
382
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(2)%is_temporary = .false.
383
384 ! W_V
385
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 21 times.
✓ Branch 7 taken 7 times.
28 call this%params(3)%allocate([this%num_outputs, this%num_inputs, 1])
386
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 call this%params(3)%set_requires_grad(.true.)
387
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(3)%fix_pointer = .true.
388
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(3)%is_sample_dependent = .false.
389
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(3)%is_temporary = .false.
390
391 ! B (basis weights)
392
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 21 times.
✓ Branch 7 taken 7 times.
28 call this%params(4)%allocate([this%num_inputs, this%num_basis, 1])
393
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 call this%params(4)%set_requires_grad(.true.)
394
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(4)%fix_pointer = .true.
395
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(4)%is_sample_dependent = .false.
396
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(4)%is_temporary = .false.
397
398 ! W (bypass)
399
4/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✓ Branch 6 taken 21 times.
✓ Branch 7 taken 7 times.
28 call this%params(5)%allocate([this%num_outputs, this%num_inputs, 1])
400
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 call this%params(5)%set_requires_grad(.true.)
401
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(5)%fix_pointer = .true.
402
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(5)%is_sample_dependent = .false.
403
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
7 this%params(5)%is_temporary = .false.
404
405 ! b (bias, optional)
406
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 1 times.
7 if(this%use_bias)then
407
12/20
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✓ Branch 21 taken 6 times.
✓ Branch 22 taken 6 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 6 times.
✓ Branch 25 taken 12 times.
✓ Branch 26 taken 6 times.
24 call this%params(6)%allocate([this%bias_shape, 1])
408
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 call this%params(6)%set_requires_grad(.true.)
409
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 this%params(6)%fix_pointer = .true.
410
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 this%params(6)%is_sample_dependent = .false.
411
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
6 this%params(6)%is_temporary = .false.
412 end if
413
414
415 !---------------------------------------------------------------------------
416 ! Initialise learnable parameters
417 !---------------------------------------------------------------------------
418 call this%kernel_init%initialise( &
419 70 this%params(1)%val(:,1), &
420 fan_in = this%num_inputs, fan_out = this%key_dim, &
421 spacing = [ this%key_dim ] &
422
12/22
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✓ Branch 30 taken 7 times.
✓ Branch 31 taken 7 times.
14 )
423 call this%kernel_init%initialise( &
424 70 this%params(2)%val(:,1), &
425 fan_in = this%num_inputs, fan_out = this%key_dim, &
426 spacing = [ this%key_dim ] &
427
12/22
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✓ Branch 30 taken 7 times.
✓ Branch 31 taken 7 times.
14 )
428 call this%kernel_init%initialise( &
429 70 this%params(3)%val(:,1), &
430 fan_in = num_inputs, fan_out = this%num_outputs, &
431 spacing = [ this%num_outputs ] &
432
12/22
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✓ Branch 30 taken 7 times.
✓ Branch 31 taken 7 times.
14 )
433 call this%kernel_init%initialise( &
434 70 this%params(4)%val(:,1), &
435 fan_in = this%num_inputs, fan_out = this%num_basis, &
436 spacing = [ this%num_inputs ] &
437
12/22
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✓ Branch 30 taken 7 times.
✓ Branch 31 taken 7 times.
14 )
438 call this%kernel_init%initialise( &
439 70 this%params(5)%val(:,1), &
440 fan_in = num_inputs, fan_out = this%num_outputs, &
441 spacing = [ this%num_outputs ] &
442
12/22
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✓ Branch 30 taken 7 times.
✓ Branch 31 taken 7 times.
14 )
443
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 1 times.
7 if(this%use_bias)then
444 call this%bias_init%initialise( &
445 60 this%params(6)%val(:,1), &
446 fan_in = num_inputs, fan_out = this%num_outputs &
447
10/20
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
6 )
448 end if
449
450
451 !---------------------------------------------------------------------------
452 ! Allocate output arrays
453 !---------------------------------------------------------------------------
454
1/6
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
7 if(allocated(this%output)) deallocate(this%output)
455
15/26
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 7 times.
✓ Branch 33 taken 7 times.
✓ Branch 34 taken 7 times.
✓ Branch 35 taken 7 times.
✓ Branch 36 taken 7 times.
21 allocate(this%output(1,1))
456
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
7 if(this%z(1)%allocated) call this%z(1)%deallocate()
457
458 7 end subroutine init_ono_attn
459 !###############################################################################
460
461
462 !###############################################################################
463 2 function get_bases_ono_attn(this) result(phi)
464 !! Orthogonalise the basis matrix B using modified Gram-Schmidt
465 implicit none
466
467 ! Arguments
468 class(orthogonal_attention_layer_type), intent(in) :: this
469 !! Layer instance providing basis parameters
470 type(array_type) :: phi
471 !! Orthogonalised basis matrix packed in an array_type
472
473 ! Local variables
474 integer :: n, k, i, j
475 !! Basis dimensions and Gram-Schmidt loop indices
476 2 real(real32), allocatable :: B(:,:), Q(:,:)
477 !! Raw basis matrix and orthogonalised copy
478 real(real32) :: norm_val, proj
479 !! Gram-Schmidt norm and projection scalars
480
481 2 n = this%num_inputs
482 2 k = this%num_basis
483
484
17/34
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 2 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 2 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 2 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 2 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 2 times.
✓ Branch 19 taken 2 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 23 taken 2 times.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 2 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 2 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 2 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 2 times.
2 allocate(B(n, k), Q(n, k))
485
486 ! Reshape B from flat params(4) into [n, k]
487
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✓ Branch 30 taken 4 times.
✓ Branch 31 taken 2 times.
✓ Branch 33 taken 2 times.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 2 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 2 times.
6 B = reshape(this%params(4)%val(:,1), [n, k])
488
489 ! Modified Gram-Schmidt orthogonalisation
490
15/36
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✓ Branch 24 taken 2 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✓ Branch 41 taken 2 times.
✓ Branch 42 taken 48 times.
✓ Branch 43 taken 6 times.
56 Q = B
491
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 do j = 1, k
492 ! Subtract projections of previous orthogonal vectors
493
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 6 times.
12 do i = 1, j - 1
494
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✓ Branch 39 taken 48 times.
✓ Branch 40 taken 6 times.
54 proj = dot_product(Q(:,i), Q(:,j))
495
22/42
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 6 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 6 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 6 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 6 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 6 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 6 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 6 times.
✓ Branch 60 taken 48 times.
✓ Branch 61 taken 6 times.
60 Q(:,j) = Q(:,j) - proj * Q(:,i)
496 end do
497 ! Normalise
498
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✓ Branch 39 taken 48 times.
✓ Branch 40 taken 6 times.
54 norm_val = sqrt(dot_product(Q(:,j), Q(:,j)))
499
1/2
✓ Branch 0 taken 6 times.
✗ Branch 1 not taken.
8 if(norm_val .gt. 1.0e-12_real32)then
500
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 6 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 6 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 6 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 6 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 6 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 6 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 6 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 6 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 6 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 6 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 6 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 6 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 6 times.
✓ Branch 39 taken 48 times.
✓ Branch 40 taken 6 times.
54 Q(:,j) = Q(:,j) / norm_val
501 else
502 Q(:,j) = 0.0_real32
503 end if
504 end do
505
506 ! Store in phi as a fixed array_type
507
2/2
✓ Branch 0 taken 6 times.
✓ Branch 1 taken 2 times.
8 call phi%allocate([n, k, 1])
508 2 phi%is_sample_dependent = .false.
509 2 phi%requires_grad = .false.
510 2 phi%fix_pointer = .true.
511 2 phi%is_temporary = .false.
512
10/16
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✓ Branch 21 taken 48 times.
✓ Branch 22 taken 2 times.
52 phi%val(:,1) = reshape(Q, [n * k])
513
514
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 deallocate(B, Q)
515
516
3/6
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 2 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 2 times.
4 end function get_bases_ono_attn
517 !###############################################################################
518
519
520 !###############################################################################
521 1 subroutine print_to_unit_ono_attn(this, unit)
522 !! Print orthogonal attention layer settings and parameters to a unit
523 use coreutils, only: to_upper
524 implicit none
525
526 ! Arguments
527 class(orthogonal_attention_layer_type), intent(in) :: this
528 !! Layer instance to print
529 integer, intent(in) :: unit
530 !! Output unit number
531
532 1 write(unit,'(3X,"NUM_INPUTS = ",I0)') this%num_inputs
533 1 write(unit,'(3X,"NUM_OUTPUTS = ",I0)') this%num_outputs
534 1 write(unit,'(3X,"NUM_BASIS = ",I0)') this%num_basis
535 1 write(unit,'(3X,"KEY_DIM = ",I0)') this%key_dim
536 1 write(unit,'(3X,"USE_BIAS = ",L1)') this%use_bias
537
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%activation%name .ne. 'none')then
538 1 call this%activation%print_to_unit(unit)
539 end if
540
541 1 write(unit,'("WEIGHTS")')
542
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 20 times.
✓ Branch 26 taken 1 times.
21 write(unit,'(5(E16.8E2))') this%params(1)%val(:,1) ! W_Q
543
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 20 times.
✓ Branch 26 taken 1 times.
21 write(unit,'(5(E16.8E2))') this%params(2)%val(:,1) ! W_K
544
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 15 times.
✓ Branch 26 taken 1 times.
16 write(unit,'(5(E16.8E2))') this%params(3)%val(:,1) ! W_V
545
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 10 times.
✓ Branch 26 taken 1 times.
11 write(unit,'(5(E16.8E2))') this%params(4)%val(:,1) ! B
546
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 15 times.
✓ Branch 26 taken 1 times.
16 write(unit,'(5(E16.8E2))') this%params(5)%val(:,1) ! W
547
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(this%use_bias)then
548
10/18
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 1 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 1 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 1 times.
✓ Branch 25 taken 3 times.
✓ Branch 26 taken 1 times.
4 write(unit,'(5(E16.8E2))') this%params(6)%val(:,1) ! b
549 end if
550 1 write(unit,'("END WEIGHTS")')
551
552 1 end subroutine print_to_unit_ono_attn
553 !###############################################################################
554
555
556 !###############################################################################
557 1 subroutine read_ono_attn(this, unit, verbose)
558 use athena__tools_infile, only: assign_val, assign_vec, move
559 use coreutils, only: to_lower, to_upper, icount
560 use athena__activation, only: read_activation
561 use athena__initialiser, only: initialiser_setup
562 implicit none
563
564 ! Arguments
565 class(orthogonal_attention_layer_type), intent(inout) :: this
566 !! Layer instance to populate from file data
567 integer, intent(in) :: unit
568 !! Input unit number
569 integer, optional, intent(in) :: verbose
570 !! Verbosity level
571
572 ! Local variables
573 integer :: stat, verbose_ = 0
574 !! I/O status and effective verbosity level
575 integer :: j, k, c, itmp1, iline
576 !! Loop counters and parser scratch integers
577 integer :: num_inputs, num_outputs, num_basis, key_dim
578 !! Parsed layer dimensions
579 logical :: use_bias = .true.
580 !! Parsed bias flag
581 character(14) :: kernel_initialiser_name='', bias_initialiser_name=''
582 !! Parsed initialiser names
583 3 class(base_actv_type), allocatable :: activation
584 !! Parsed activation object
585 5 class(base_init_type), allocatable :: kernel_initialiser, bias_initialiser
586 !! Parsed initialiser objects
587 character(256) :: buffer, tag, err_msg
588 !! Input buffer, parsed tag and formatted error message
589 1 real(real32), allocatable, dimension(:) :: data_list
590 !! Temporary storage for flattened parameter blocks
591 integer :: param_line, final_line, num_vals
592 !! Weights-section line markers and current block size
593
594
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(present(verbose)) verbose_ = verbose
595
596 1 key_dim = 0
597 1 iline = 0
598 1 param_line = 0
599 1 final_line = 0
600 25 tag_loop: do
601 26 read(unit,'(A)',iostat=stat) buffer
602
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 26 times.
26 if(stat.ne.0)then
603 write(err_msg,'("file encountered error (EoF?) before END ",A)') &
604 to_upper(this%name)
605 call stop_program(err_msg)
606 return
607 end if
608
2/4
✓ Branch 2 taken 26 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 26 times.
26 if(trim(adjustl(buffer)).eq."") cycle tag_loop
609
610
4/6
✓ Branch 3 taken 26 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 26 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 1 times.
✓ Branch 9 taken 25 times.
52 if(trim(adjustl(buffer)).eq."END "//to_upper(trim(this%name)))then
611 1 final_line = iline
612 1 backspace(unit)
613 26 exit tag_loop
614 end if
615 25 iline = iline + 1
616
617
2/4
✓ Branch 2 taken 25 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 25 times.
✗ Branch 5 not taken.
25 tag=trim(adjustl(buffer))
618
6/10
✓ Branch 0 taken 5 times.
✓ Branch 1 taken 20 times.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✓ Branch 8 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 11 not taken.
25 if(scan(buffer,"=").ne.0) tag=trim(tag(:scan(tag,"=")-1))
619
620 50 select case(trim(tag))
621 case("NUM_INPUTS")
622 2 call assign_val(buffer, num_inputs, itmp1)
623 case("NUM_OUTPUTS")
624 2 call assign_val(buffer, num_outputs, itmp1)
625 case("NUM_BASIS")
626 2 call assign_val(buffer, num_basis, itmp1)
627 case("KEY_DIM")
628 2 call assign_val(buffer, key_dim, itmp1)
629 case("USE_BIAS")
630 2 call assign_val(buffer, use_bias, itmp1)
631 case("ACTIVATION")
632 1 iline = iline - 1
633 1 backspace(unit)
634
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 activation = read_activation(unit, iline)
635 case("KERNEL_INITIALISER", "KERNEL_INIT", "KERNEL_INITIALIZER")
636 call assign_val(buffer, kernel_initialiser_name, itmp1)
637 case("BIAS_INITIALISER", "BIAS_INIT", "BIAS_INITIALIZER")
638 call assign_val(buffer, bias_initialiser_name, itmp1)
639 case("WEIGHTS")
640 1 kernel_initialiser_name = 'zeros'
641 1 bias_initialiser_name = 'zeros'
642 1 param_line = iline
643 case default
644
3/4
✓ Branch 2 taken 18 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 17 times.
✓ Branch 5 taken 1 times.
36 if(scan(to_lower(trim(adjustl(buffer))),&
645 'abcdfghijklmnopqrstuvwxyz').eq.0)then
646 18 cycle tag_loop
647
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 elseif(tag(:3).eq.'END')then
648 18 cycle tag_loop
649 end if
650 write(err_msg,'("Unrecognised line in input file: ",A)') &
651 trim(adjustl(buffer))
652 call stop_program(err_msg)
653
9/12
✓ Branch 0 taken 25 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✓ Branch 5 taken 1 times.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 11 taken 18 times.
50 return
654 end select
655 end do tag_loop
656
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 kernel_initialiser = initialiser_setup(kernel_initialiser_name)
657
5/14
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✗ Branch 3 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
1 bias_initialiser = initialiser_setup(bias_initialiser_name)
658
659
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(key_dim .eq. 0) key_dim = num_basis
660
661 call this%set_hyperparams( &
662 num_outputs = num_outputs, &
663 num_basis = num_basis, &
664 key_dim = key_dim, &
665 use_bias = use_bias, &
666 activation = activation, &
667 kernel_initialiser = kernel_initialiser, &
668 bias_initialiser = bias_initialiser, &
669 verbose = verbose_ &
670 1 )
671
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 call this%init(input_shape=[num_inputs])
672
673
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(param_line.eq.0)then
674 write(0,*) "WARNING: WEIGHTS card in " // trim(this%name) // " not found"
675 else
676 1 call move(unit, param_line - iline, iostat=stat)
677
678 ! Read W_Q
679 1 num_vals = key_dim * num_inputs
680
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_vals), source=0._real32)
681 1 c = 1; k = 1
682
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 4 times.
5 do while(c.le.num_vals)
683 4 read(unit,'(A)',iostat=stat) buffer
684
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if(stat.ne.0) exit
685 4 k = icount(buffer)
686
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 20 times.
24 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
687 4 c = c + k
688 end do
689
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 20 times.
✓ Branch 40 taken 1 times.
21 this%params(1)%val(:,1) = data_list
690
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
691
692 ! Read W_K
693 1 num_vals = key_dim * num_inputs
694
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_vals), source=0._real32)
695 1 c = 1; k = 1
696
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 4 times.
5 do while(c.le.num_vals)
697 4 read(unit,'(A)',iostat=stat) buffer
698
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
4 if(stat.ne.0) exit
699 4 k = icount(buffer)
700
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 24 times.
✓ Branch 3 taken 20 times.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 20 times.
24 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
701 4 c = c + k
702 end do
703
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 20 times.
✓ Branch 40 taken 1 times.
21 this%params(2)%val(:,1) = data_list
704
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
705
706 ! Read W_V
707 1 num_vals = num_outputs * num_inputs
708
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_vals), source=0._real32)
709 1 c = 1; k = 1
710
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 3 times.
4 do while(c.le.num_vals)
711 3 read(unit,'(A)',iostat=stat) buffer
712
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 if(stat.ne.0) exit
713 3 k = icount(buffer)
714
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✓ Branch 3 taken 15 times.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 15 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 15 times.
18 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
715 3 c = c + k
716 end do
717
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 15 times.
✓ Branch 40 taken 1 times.
16 this%params(3)%val(:,1) = data_list
718
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
719
720 ! Read B
721 1 num_vals = num_inputs * num_basis
722
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_vals), source=0._real32)
723 1 c = 1; k = 1
724
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
3 do while(c.le.num_vals)
725 2 read(unit,'(A)',iostat=stat) buffer
726
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
2 if(stat.ne.0) exit
727 2 k = icount(buffer)
728
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 12 times.
✓ Branch 3 taken 10 times.
✓ Branch 4 taken 2 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 10 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 10 times.
12 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
729 2 c = c + k
730 end do
731
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 10 times.
✓ Branch 40 taken 1 times.
11 this%params(4)%val(:,1) = data_list
732
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
733
734 ! Read W (bypass)
735 1 num_vals = num_outputs * num_inputs
736
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_vals), source=0._real32)
737 1 c = 1; k = 1
738
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 3 times.
4 do while(c.le.num_vals)
739 3 read(unit,'(A)',iostat=stat) buffer
740
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 3 times.
3 if(stat.ne.0) exit
741 3 k = icount(buffer)
742
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 18 times.
✓ Branch 3 taken 15 times.
✓ Branch 4 taken 3 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 15 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 15 times.
18 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
743 3 c = c + k
744 end do
745
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 15 times.
✓ Branch 40 taken 1 times.
16 this%params(5)%val(:,1) = data_list
746
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
747
748 ! Read b if use_bias
749
1/2
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
1 if(use_bias)then
750
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 allocate(data_list(num_outputs), source=0._real32)
751 1 c = 1; k = 1
752
2/2
✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
2 do while(c.le.num_outputs)
753 1 read(unit,'(A)',iostat=stat) buffer
754
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(stat.ne.0) exit
755 1 k = icount(buffer)
756
5/8
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✓ Branch 3 taken 3 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 3 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 3 times.
4 read(buffer,*,iostat=stat) (data_list(j),j=c,c+k-1)
757 1 c = c + k
758 end do
759
15/28
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 3 times.
✓ Branch 40 taken 1 times.
4 this%params(6)%val(:,1) = data_list(1:num_outputs)
760
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 deallocate(data_list)
761 end if
762
763 1 read(unit,'(A)') buffer
764
2/4
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
1 if(trim(adjustl(buffer)).ne."END WEIGHTS")then
765 call stop_program("END WEIGHTS not where expected")
766 return
767 end if
768 end if
769
770 1 call move(unit, final_line - iline, iostat=stat)
771 1 read(unit,'(A)') buffer
772
3/6
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
2 if(trim(adjustl(buffer)).ne."END "//to_upper(trim(this%name)))then
773 write(err_msg,'("END ",A," not where expected")') to_upper(this%name)
774 call stop_program(err_msg)
775 1 return
776 end if
777
778
7/14
✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 5 not taken.
✓ Branch 6 taken 1 times.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✓ Branch 12 taken 1 times.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
1 end subroutine read_ono_attn
779 !###############################################################################
780
781
782 !###############################################################################
783 1 function read_orthogonal_attention_layer(unit, verbose) result(layer)
784 !! Read an orthogonal attention layer from file and return it
785 implicit none
786
787 ! Arguments
788 integer, intent(in) :: unit
789 !! Input unit number
790 integer, optional, intent(in) :: verbose
791 !! Verbosity level
792 class(base_layer_type), allocatable :: layer
793 !! Allocated base-layer instance containing the result
794
795 ! Local variables
796 integer :: verbose_ = 0
797 !! Effective verbosity level
798
799
1/2
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
1 if(present(verbose)) verbose_ = verbose
800 allocate(layer, source=orthogonal_attention_layer_type( &
801
21/78
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✓ Branch 6 taken 1 times.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✗ Branch 41 not taken.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✗ Branch 45 not taken.
✗ Branch 46 not taken.
✓ Branch 47 taken 1 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 53 not taken.
✓ Branch 54 taken 1 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 1 times.
✗ Branch 62 not taken.
✓ Branch 63 taken 1 times.
✓ Branch 64 taken 1 times.
✗ Branch 65 not taken.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✓ Branch 69 taken 1 times.
✗ Branch 70 not taken.
✓ Branch 71 taken 1 times.
✗ Branch 72 not taken.
✗ Branch 73 not taken.
✓ Branch 74 taken 1 times.
✓ Branch 76 taken 1 times.
✗ Branch 77 not taken.
✓ Branch 78 taken 1 times.
✗ Branch 79 not taken.
✗ Branch 80 not taken.
✓ Branch 81 taken 1 times.
✓ Branch 83 taken 1 times.
✗ Branch 84 not taken.
2 num_outputs=0, num_basis=1))
802 1 call layer%read(unit, verbose=verbose_)
803
804 2 end function read_orthogonal_attention_layer
805 !###############################################################################
806
807
808 !###############################################################################
809
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 subroutine forward_ono_attn(this, input)
810 !! Forward propagation for the Orthogonal Attention layer
811 !!
812 !! Computes:
813 !! Q = W_Q @ u [k, batch]
814 !! K = W_K @ u [k, batch]
815 !!
816 !! scores = tanh( (Q * K) / sqrt(k) ) [k, batch]
817 !! bounded per-basis interaction scores
818 !!
819 !! attn = softmax(scores, dim=1) [k, batch]
820 !! normalised attention weights across basis modes
821 !!
822 !! spectral = Q(B)^T @ u [k, batch]
823 !! project input to orthogonal spectral basis
824 !!
825 !! modulated = spectral + attn * spectral [k, batch]
826 !! residual spectral modulation
827 !!
828 !! decoded = Q(B) @ modulated [n_in, batch]
829 !! decode modulated spectral representation
830 !!
831 !! attn_out = W_V @ decoded [n_out, batch]
832 !! bypass = W @ u [n_out, batch]
833 !!
834 !! v = sigma( attn_out + bypass + b )
835 implicit none
836
837 ! Arguments
838 class(orthogonal_attention_layer_type), intent(inout) :: this
839 !! Layer instance to execute
840 class(array_type), dimension(:,:), intent(in) :: input
841 !! Input batch tensor collection
842
843 ! Local variables
844 type(array_type), pointer :: ptr, ptr_attn, ptr_bypass
845 !! Combined output, attention-path output and bypass-path output
846 type(array_type), pointer :: ptr_Q, ptr_K, ptr_coeff
847 !! Query, key and per-basis attention coefficient tensors
848 type(array_type), pointer :: ptr_spec, ptr_mod, ptr_decoded
849 !! Spectral encoding, modulated spectrum and decoded tensors
850
851 integer :: n, nb
852 !! Input size and basis count
853 real(real32) :: scale
854 !! Precomputed scaling factor for attention scores
855
856
857 2 n = this%num_inputs
858 2 nb = this%num_basis
859
860
861 !---------------------------------------------------------------------------
862 ! Scaling (critical for stability)
863 !---------------------------------------------------------------------------
864 2 scale = 1.0_real32 / sqrt(real(this%key_dim, kind=real32))
865
866
867 !---------------------------------------------------------------------------
868 ! Attention scores from Q and K projections
869 !---------------------------------------------------------------------------
870
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 ptr_Q => matmul(this%params(1), input(1,1)) ! W_Q @ u: [k, batch]
871
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 ptr_K => matmul(this%params(2), input(1,1)) ! W_K @ u: [k, batch]
872
873
874 !---------------------------------------------------------------------------
875 ! Stable interaction (bounded instead of raw product)
876 !---------------------------------------------------------------------------
877 2 ptr_coeff => ptr_Q * ptr_K * scale ! scaled interaction
878 2 ptr_coeff => tanh(ptr_coeff) ! bound to [-1, 1]
879 2 ptr_coeff => softmax(ptr_coeff, 1) ! [k, batch], sum_k = 1
880
881
882 !---------------------------------------------------------------------------
883 ! Spectral pathway: modulate spectral coefficients by attention scores
884 !---------------------------------------------------------------------------
885
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 ptr_spec => ono_encode(input(1,1), this%params(4), n, nb) ! [k, batch]
886 2 ptr_mod => ptr_coeff * ptr_spec ! [k, batch]
887
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 ptr_decoded => ono_decode(ptr_mod, this%params(4), n, nb) ! [n, batch]
888
889 ! Value projection
890
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 ptr_attn => matmul(this%params(3), ptr_decoded) ! [n_out, batch]
891
892 ! Bypass: W @ u
893
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 ptr_bypass => matmul(this%params(5), input(1,1)) ! [n_out, batch]
894
895 ! Combine: attn_out + bypass
896 2 ptr => ptr_attn + ptr_bypass
897
898 ! Add bias
899
1/2
✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
2 if(this%use_bias)then
900
2/4
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
2 ptr => ptr + this%params(6)
901 end if
902
903 ! Apply activation
904
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 call this%output(1,1)%zero_grad()
905
3/4
✓ Branch 1 taken 2 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
2 if(trim(this%activation%name) .eq. "none")then
906
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 call this%output(1,1)%assign_and_deallocate_source(ptr)
907 else
908 1 call this%z(1)%zero_grad()
909 1 call this%z(1)%assign_and_deallocate_source(ptr)
910 1 this%z(1)%is_temporary = .false.
911 1 ptr => this%activation%apply(this%z(1))
912
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
1 call this%output(1,1)%assign_and_deallocate_source(ptr)
913 end if
914
4/8
✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
2 this%output(1,1)%is_temporary = .false.
915
916 2 end subroutine forward_ono_attn
917 !###############################################################################
918
919
920 !###############################################################################
921 function get_attributes_ono_attn(this) result(attributes)
922 !! Return list of orthogonal attention attributes for ONNX export
923 implicit none
924
925 ! Arguments
926 class(orthogonal_attention_layer_type), intent(in) :: this
927 !! Instance of the orthogonal attention layer
928 type(onnx_attribute_type), allocatable, dimension(:) :: attributes
929 !! List of attributes for ONNX export
930
931 ! Local variables
932 character(32) :: buffer
933 !! Buffer for integer-to-string conversion
934
935 allocate(attributes(6))
936
937 write(buffer, '(I0)') this%num_inputs
938 attributes(1) = onnx_attribute_type( &
939 name='num_inputs', type='int', val=trim(buffer))
940 write(buffer, '(I0)') this%num_outputs
941 attributes(2) = onnx_attribute_type( &
942 name='num_outputs', type='int', val=trim(buffer))
943 write(buffer, '(I0)') this%num_basis
944 attributes(3) = onnx_attribute_type( &
945 name='num_basis', type='int', val=trim(buffer))
946 write(buffer, '(I0)') this%key_dim
947 attributes(4) = onnx_attribute_type( &
948 name='key_dim', type='int', val=trim(buffer))
949 if(this%use_bias)then
950 buffer = '1'
951 else
952 buffer = '0'
953 end if
954 attributes(5) = onnx_attribute_type( &
955 name='use_bias', type='int', val=trim(buffer))
956 attributes(6) = onnx_attribute_type( &
957 name='activation', type='string', val=trim(this%activation%name))
958
959 end function get_attributes_ono_attn
960 !###############################################################################
961
962
42/91
✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 8 times.
✗ Branch 6 not taken.
✗ Branch 7 not taken.
✗ Branch 8 not taken.
✗ Branch 9 not taken.
✗ Branch 10 not taken.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✗ Branch 15 not taken.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✗ Branch 24 not taken.
✗ Branch 25 not taken.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✗ Branch 28 not taken.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✓ Branch 36 taken 6 times.
✓ Branch 37 taken 2 times.
✗ Branch 38 not taken.
✓ Branch 39 taken 6 times.
✓ Branch 40 taken 4 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 6 times.
✓ Branch 43 taken 4 times.
✓ Branch 44 taken 6 times.
✓ Branch 45 taken 10 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 6 times.
✓ Branch 48 taken 6 times.
✓ Branch 49 taken 10 times.
✓ Branch 50 taken 2 times.
✓ Branch 51 taken 10 times.
✓ Branch 52 taken 2 times.
✓ Branch 53 taken 5 times.
✓ Branch 54 taken 3 times.
✓ Branch 55 taken 6 times.
✓ Branch 56 taken 2 times.
✓ Branch 57 taken 35 times.
✓ Branch 58 taken 6 times.
✓ Branch 59 taken 35 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 35 times.
✗ Branch 62 not taken.
✗ Branch 63 not taken.
✓ Branch 64 taken 35 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 35 times.
✗ Branch 67 not taken.
✓ Branch 68 taken 35 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 35 times.
✓ Branch 71 taken 8 times.
✗ Branch 72 not taken.
✓ Branch 74 taken 8 times.
✗ Branch 75 not taken.
✓ Branch 77 taken 8 times.
✗ Branch 78 not taken.
✓ Branch 80 taken 8 times.
✓ Branch 81 taken 8 times.
✗ Branch 82 not taken.
✓ Branch 83 taken 8 times.
✗ Branch 84 not taken.
✓ Branch 85 taken 8 times.
✗ Branch 86 not taken.
✓ Branch 87 taken 8 times.
✗ Branch 88 not taken.
✓ Branch 89 taken 8 times.
✗ Branch 90 not taken.
✓ Branch 91 taken 8 times.
✗ Branch 92 not taken.
✓ Branch 93 taken 8 times.
154 end module athena__orthogonal_attention_layer
963