| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | submodule (athena__diffstruc_extd) athena__diffstruc_extd_submodule_nop | ||
| 2 | !! Submodule containing autodiff operations for the Graph Neural Operator | ||
| 3 | !! | ||
| 4 | !! Provides two differentiable operations: | ||
| 5 | !! | ||
| 6 | !! 1. `gno_kernel_eval` — evaluates the kernel MLP on each edge: | ||
| 7 | !! \(\kappa(\Delta x) = V \, \mathrm{relu}(U \Delta x + b_u) + b_v\) | ||
| 8 | !! left_operand → edge_features [d, num_edges] | ||
| 9 | !! right_operand → packed kernel params [H*d + H + F*H + F, 1] | ||
| 10 | !! where F = F_out * F_in | ||
| 11 | !! output → [F_out * F_in, num_edges] per-edge kernel values | ||
| 12 | !! | ||
| 13 | !! 2. `gno_aggregate` — aggregates messages using per-edge kernels: | ||
| 14 | !! \(m_i = \sum_{j \in \mathcal{N}(i)} \kappa_{ij} \, h_j\) | ||
| 15 | !! left_operand → features [F_in, num_vertices] | ||
| 16 | !! right_operand → edge_kernels [F_out * F_in, num_edges] | ||
| 17 | !! output → [F_out, num_vertices] | ||
| 18 | !! | ||
| 19 | !! `gno_aggregate` stores `adj_ia` and `adj_ja` on the result for | ||
| 20 | !! use in the backward pass. Metadata (d, H, F_in, F_out) is stored | ||
| 21 | !! in `indices` of the kernel evaluation result. | ||
| 22 | |||
| 23 | contains | ||
| 24 | |||
| 25 | !############################################################################### | ||
| 26 | 5 | module function gno_kernel_eval( & | |
| 27 | coords, kernel_params, adj_ia, adj_ja, & | ||
| 28 | coord_dim, kernel_hidden, F_in, F_out & | ||
| 29 | ) result(c) | ||
| 30 | !! Evaluate the GNO kernel MLP on every directed edge in the graph. | ||
| 31 | !! | ||
| 32 | !! For each edge feature column e, compute: | ||
| 33 | !! dx = edge_features(:,e) [d] | ||
| 34 | !! hidden = relu( U @ dx + b_u ) [H] | ||
| 35 | !! kappa_e = V @ hidden + b_v [F_out*F_in] | ||
| 36 | !! | ||
| 37 | !! Kernel params layout (flat column, size H*d + H + F*H + F): | ||
| 38 | !! U : params(1 : H*d) | ||
| 39 | !! b_u : params(H*d+1 : H*d+H) | ||
| 40 | !! V : params(H*d+H+1 : H*d+H+F*H) | ||
| 41 | !! b_v : params(H*d+H+F*H+1 : end) | ||
| 42 | implicit none | ||
| 43 | |||
| 44 | ! Arguments | ||
| 45 | class(array_type), intent(in), target :: coords | ||
| 46 | !! Edge features / relative coordinates [d, num_edges] | ||
| 47 | class(array_type), intent(in), target :: kernel_params | ||
| 48 | !! Packed kernel parameters [H*d + H + F*H + F, 1] | ||
| 49 | integer, dimension(:), intent(in) :: adj_ia | ||
| 50 | !! CSR row pointers (size num_vertices + 1) | ||
| 51 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 52 | !! CSR column indices (adj_ja(1,:) = neighbour index) | ||
| 53 | integer, intent(in) :: coord_dim, kernel_hidden, F_in, F_out | ||
| 54 | !! Metadata for unpacking kernel_params | ||
| 55 | type(array_type), pointer :: c | ||
| 56 | !! Output per-edge kernel values | ||
| 57 | |||
| 58 | ! Local variables | ||
| 59 | integer :: num_e, d, H, F, e | ||
| 60 | !! Edge count, unpacked dimensions and edge loop index | ||
| 61 | integer :: off_U, off_bu, off_V, off_bv | ||
| 62 | !! Flat offsets for packed kernel parameter blocks | ||
| 63 | 5 | real(real32), allocatable :: U(:,:), b_u(:), V(:,:), b_v(:) | |
| 64 | !! Unpacked kernel parameter tensors | ||
| 65 | 5 | real(real32), allocatable :: dx(:), hidden(:) | |
| 66 | !! Per-edge coordinate and hidden activation buffers | ||
| 67 | |||
| 68 | 5 | d = coord_dim | |
| 69 | 5 | H = kernel_hidden | |
| 70 | 5 | F = F_out * F_in ! kernel output width | |
| 71 | 5 | num_e = size(coords%val, 2) | |
| 72 | |||
| 73 | ! ---- Unpack kernel params ------------------------------------------------ | ||
| 74 | 5 | off_U = 0 | |
| 75 | 5 | off_bu = H * d | |
| 76 | 5 | off_V = off_bu + H | |
| 77 | 5 | off_bv = off_V + F * H | |
| 78 | |||
| 79 |
20/38✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✓ Branch 4 taken 5 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✓ Branch 39 taken 10 times.
✓ Branch 40 taken 5 times.
✓ Branch 42 taken 5 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 5 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 5 times.
|
15 | allocate(U(H, d)); U = reshape(kernel_params%val(off_U+1:off_bu, 1), [H, d]) |
| 80 |
17/34✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 5 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 5 times.
✓ Branch 35 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✓ Branch 41 taken 25 times.
✓ Branch 42 taken 5 times.
|
30 | allocate(b_u(H)); b_u = kernel_params%val(off_bu+1:off_V, 1) |
| 81 |
20/38✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✓ Branch 4 taken 5 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 5 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 5 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 5 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✓ Branch 39 taken 10 times.
✓ Branch 40 taken 5 times.
✓ Branch 42 taken 5 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 5 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 5 times.
|
15 | allocate(V(F, H)); V = reshape(kernel_params%val(off_V+1:off_bv, 1), [F, H]) |
| 82 |
17/34✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 5 times.
✗ Branch 17 not taken.
✓ Branch 18 taken 5 times.
✗ Branch 20 not taken.
✓ Branch 21 taken 5 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 26 not taken.
✓ Branch 27 taken 5 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 5 times.
✗ Branch 32 not taken.
✓ Branch 33 taken 5 times.
✓ Branch 35 taken 5 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 5 times.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✗ Branch 40 not taken.
✓ Branch 41 taken 60 times.
✓ Branch 42 taken 5 times.
|
65 | allocate(b_v(F)); b_v = kernel_params%val(off_bv+1:, 1) |
| 83 | |||
| 84 | ! ---- Forward: evaluate kernel on every edge ------------------------------ | ||
| 85 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 5 times.
|
15 | c => coords%create_result(array_shape=[F, num_e]) |
| 86 |
14/28✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 5 times.
✓ Branch 17 taken 5 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 5 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 5 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 5 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 5 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 5 times.
|
5 | allocate(dx(d), hidden(H)) |
| 87 | |||
| 88 |
2/2✓ Branch 0 taken 40 times.
✓ Branch 1 taken 5 times.
|
45 | do e = 1, num_e |
| 89 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✓ Branch 18 taken 40 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 40 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 80 times.
✓ Branch 25 taken 40 times.
|
120 | dx = coords%val(:, e) |
| 90 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✓ Branch 13 taken 40 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 40 times.
✗ Branch 16 not taken.
✗ Branch 17 not taken.
✗ Branch 18 not taken.
✓ Branch 19 taken 200 times.
✓ Branch 20 taken 40 times.
|
240 | hidden = matmul(U, dx) + b_u |
| 91 |
6/10✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✓ Branch 12 taken 200 times.
✓ Branch 13 taken 40 times.
|
240 | hidden = max(hidden, 0.0_real32) ! ReLU |
| 92 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 40 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 40 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 40 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 40 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 40 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 40 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 40 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 40 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 40 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 40 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 40 times.
✓ Branch 34 taken 480 times.
✓ Branch 35 taken 40 times.
|
525 | c%val(:, e) = matmul(V, hidden) + b_v |
| 93 | end do | ||
| 94 | |||
| 95 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
|
5 | deallocate(dx, hidden, U, b_u, V, b_v) |
| 96 | |||
| 97 | ! ---- Store metadata for backward ----------------------------------------- | ||
| 98 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
|
5 | allocate(c%indices(4)) |
| 99 |
4/8✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 20 times.
✓ Branch 7 taken 5 times.
|
30 | c%indices = [d, H, F_in, F_out] |
| 100 | |||
| 101 | 5 | c%get_partial_left => get_partial_gno_kernel_coords | |
| 102 | 5 | c%get_partial_right => get_partial_gno_kernel_params | |
| 103 | 5 | c%get_partial_left_val => get_partial_gno_kernel_coords_val | |
| 104 | 5 | c%get_partial_right_val => get_partial_gno_kernel_params_val | |
| 105 |
1/2✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
|
5 | if(coords%requires_grad .or. kernel_params%requires_grad)then |
| 106 | 5 | c%requires_grad = .true. | |
| 107 | 5 | c%is_forward = coords%is_forward .or. kernel_params%is_forward | |
| 108 | 5 | c%operation = 'gno_kernel_eval' | |
| 109 | 5 | c%left_operand => coords | |
| 110 | 5 | c%right_operand => kernel_params | |
| 111 | 5 | c%owns_left_operand = coords%is_temporary | |
| 112 | 5 | c%owns_right_operand = kernel_params%is_temporary | |
| 113 | end if | ||
| 114 | |||
| 115 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 5 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 5 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 5 times.
|
5 | end function gno_kernel_eval |
| 116 | !------------------------------------------------------------------------------- | ||
| 117 | ✗ | function get_partial_gno_kernel_coords(this, upstream_grad) result(output) | |
| 118 | !! Gradient of gno_kernel_eval w.r.t. edge features (left operand) | ||
| 119 | !! | ||
| 120 | !! upstream_grad has shape [F, num_edges] | ||
| 121 | !! output has shape [d, num_edges] | ||
| 122 | implicit none | ||
| 123 | |||
| 124 | ! Arguments | ||
| 125 | class(array_type), intent(inout) :: this | ||
| 126 | !! Forward result node containing saved operands | ||
| 127 | type(array_type), intent(in) :: upstream_grad | ||
| 128 | !! Upstream gradient tensor | ||
| 129 | type(array_type) :: output | ||
| 130 | !! Gradient tensor for coordinates | ||
| 131 | |||
| 132 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 133 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 134 | |||
| 135 | ✗ | end function get_partial_gno_kernel_coords | |
| 136 | !------------------------------------------------------------------------------- | ||
| 137 | ✗ | pure subroutine get_partial_gno_kernel_coords_val( & | |
| 138 | ✗ | this, upstream_grad, output) | |
| 139 | !! In-place gradient w.r.t. edge features | ||
| 140 | !! | ||
| 141 | !! Chain rule through kernel: | ||
| 142 | !! kappa_e = V @ relu(U @ dx_e + b_u) + b_v | ||
| 143 | !! d(kappa_e)/d(dx_e) = V @ diag(relu'(U dx_e + b_u)) @ U | ||
| 144 | !! Since the left operand already stores edge features directly, | ||
| 145 | !! gradients accumulate independently for each edge column. | ||
| 146 | implicit none | ||
| 147 | |||
| 148 | ! Arguments | ||
| 149 | class(array_type), intent(in) :: this | ||
| 150 | !! Forward result node containing saved operands | ||
| 151 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 152 | !! Upstream gradient values | ||
| 153 | real(real32), dimension(:,:), intent(out) :: output | ||
| 154 | !! Output gradient values for coordinates | ||
| 155 | |||
| 156 | ! Local variables | ||
| 157 | integer :: d, H, F, num_e, e, k | ||
| 158 | !! Unpacked dimensions and loop indices | ||
| 159 | integer :: off_U, off_bu, off_V | ||
| 160 | !! Flat offsets for packed kernel parameter blocks | ||
| 161 | ✗ | real(real32), allocatable :: U(:,:), b_u(:), V(:,:) | |
| 162 | !! Unpacked kernel parameter tensors | ||
| 163 | ✗ | real(real32), allocatable :: dx(:), pre_act(:), relu_mask(:) | |
| 164 | !! Per-edge buffers for input, pre-activation and ReLU mask | ||
| 165 | ✗ | real(real32), allocatable :: dkappa_ddx(:,:) ! [F, d] | |
| 166 | !! Jacobian of edge kernel values with respect to coordinates | ||
| 167 | ✗ | real(real32), allocatable :: grad_dx(:) ! [d] | |
| 168 | !! Coordinate gradient for one edge | ||
| 169 | |||
| 170 | ✗ | d = this%indices(1) | |
| 171 | ✗ | H = this%indices(2) | |
| 172 | ✗ | F = this%indices(3) * this%indices(4) | |
| 173 | ✗ | num_e = size(this%left_operand%val, 2) | |
| 174 | |||
| 175 | ✗ | off_U = 0 | |
| 176 | ✗ | off_bu = H * d | |
| 177 | ✗ | off_V = off_bu + H | |
| 178 | |||
| 179 | ✗ | allocate(U(H, d)) | |
| 180 | ✗ | U = reshape(this%right_operand%val(off_U+1:off_bu, 1), [H, d]) | |
| 181 | ✗ | allocate(b_u(H)) | |
| 182 | ✗ | b_u = this%right_operand%val(off_bu+1:off_V, 1) | |
| 183 | ✗ | allocate(V(F, H)) | |
| 184 | ✗ | V = reshape(this%right_operand%val(off_V+1:off_V+F*H, 1), [F, H]) | |
| 185 | |||
| 186 | ✗ | allocate(dx(d), pre_act(H), relu_mask(H)) | |
| 187 | ✗ | allocate(dkappa_ddx(F, d), grad_dx(d)) | |
| 188 | |||
| 189 | ✗ | output = 0.0_real32 | |
| 190 | |||
| 191 | ✗ | do e = 1, num_e | |
| 192 | ✗ | dx = this%left_operand%val(:, e) | |
| 193 | ✗ | pre_act = matmul(U, dx) + b_u | |
| 194 | ✗ | do k = 1, H | |
| 195 | ✗ | if(pre_act(k) .gt. 0.0_real32)then | |
| 196 | ✗ | relu_mask(k) = 1.0_real32 | |
| 197 | else | ||
| 198 | ✗ | relu_mask(k) = 0.0_real32 | |
| 199 | end if | ||
| 200 | end do | ||
| 201 | |||
| 202 | ✗ | dkappa_ddx = 0.0_real32 | |
| 203 | ✗ | do k = 1, H | |
| 204 | ✗ | if(relu_mask(k) .gt. 0.0_real32)then | |
| 205 | ✗ | dkappa_ddx = dkappa_ddx + & | |
| 206 | ✗ | spread(V(:, k), 2, d) * spread(U(k, :), 1, F) | |
| 207 | end if | ||
| 208 | end do | ||
| 209 | |||
| 210 | ✗ | grad_dx = matmul(upstream_grad(:, e), dkappa_ddx) | |
| 211 | ✗ | output(:, e) = output(:, e) + grad_dx | |
| 212 | end do | ||
| 213 | |||
| 214 | ✗ | deallocate(U, b_u, V, dx, pre_act, relu_mask, dkappa_ddx, grad_dx) | |
| 215 | |||
| 216 | ✗ | end subroutine get_partial_gno_kernel_coords_val | |
| 217 | !------------------------------------------------------------------------------- | ||
| 218 | ✗ | function get_partial_gno_kernel_params(this, upstream_grad) result(output) | |
| 219 | !! Gradient of gno_kernel_eval w.r.t. kernel_params (right operand) | ||
| 220 | implicit none | ||
| 221 | |||
| 222 | ! Arguments | ||
| 223 | class(array_type), intent(inout) :: this | ||
| 224 | !! Forward result node containing saved operands | ||
| 225 | type(array_type), intent(in) :: upstream_grad | ||
| 226 | !! Upstream gradient tensor | ||
| 227 | type(array_type) :: output | ||
| 228 | !! Gradient tensor for packed kernel parameters | ||
| 229 | |||
| 230 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 231 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 232 | |||
| 233 | ✗ | end function get_partial_gno_kernel_params | |
| 234 | !------------------------------------------------------------------------------- | ||
| 235 | ✗ | pure subroutine get_partial_gno_kernel_params_val( & | |
| 236 | ✗ | this, upstream_grad, output) | |
| 237 | !! In-place gradient w.r.t. packed kernel params | ||
| 238 | !! | ||
| 239 | !! Accumulate gradients over all edges: | ||
| 240 | !! d(kappa_e)/dU = V^T @ diag(relu_mask) outer dx | ||
| 241 | !! d(kappa_e)/db_u = V^T @ diag(relu_mask) dot upstream | ||
| 242 | !! d(kappa_e)/dV = upstream outer relu(U dx + b_u) | ||
| 243 | !! d(kappa_e)/db_v = upstream directly | ||
| 244 | implicit none | ||
| 245 | |||
| 246 | ! Arguments | ||
| 247 | class(array_type), intent(in) :: this | ||
| 248 | !! Forward result node containing saved operands | ||
| 249 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 250 | !! Upstream gradient values | ||
| 251 | real(real32), dimension(:,:), intent(out) :: output | ||
| 252 | !! Output gradient values for packed kernel parameters | ||
| 253 | |||
| 254 | ! Local variables | ||
| 255 | integer :: d, H, F, num_e, e, k, f_idx | ||
| 256 | !! Unpacked dimensions and loop indices | ||
| 257 | integer :: off_U, off_bu, off_V, off_bv | ||
| 258 | !! Flat offsets for packed kernel parameter blocks | ||
| 259 | ✗ | real(real32), allocatable :: U(:,:), b_u(:), V(:,:) | |
| 260 | !! Unpacked kernel parameter tensors | ||
| 261 | ✗ | real(real32), allocatable :: dx(:), pre_act(:), hidden(:) | |
| 262 | !! Per-edge buffers for input and activations | ||
| 263 | ✗ | real(real32), allocatable :: grad_hidden(:) ! [H] | |
| 264 | !! Hidden-layer gradient buffer | ||
| 265 | |||
| 266 | ✗ | d = this%indices(1) | |
| 267 | ✗ | H = this%indices(2) | |
| 268 | ✗ | F = this%indices(3) * this%indices(4) | |
| 269 | ✗ | num_e = size(this%left_operand%val, 2) | |
| 270 | |||
| 271 | ✗ | off_U = 0 | |
| 272 | ✗ | off_bu = H * d | |
| 273 | ✗ | off_V = off_bu + H | |
| 274 | ✗ | off_bv = off_V + F * H | |
| 275 | |||
| 276 | ✗ | allocate(U(H, d)) | |
| 277 | ✗ | U = reshape(this%right_operand%val(off_U+1:off_bu, 1), [H, d]) | |
| 278 | ✗ | allocate(b_u(H)) | |
| 279 | ✗ | b_u = this%right_operand%val(off_bu+1:off_V, 1) | |
| 280 | ✗ | allocate(V(F, H)) | |
| 281 | ✗ | V = reshape(this%right_operand%val(off_V+1:off_bv, 1), [F, H]) | |
| 282 | |||
| 283 | ✗ | allocate(dx(d), pre_act(H), hidden(H), grad_hidden(H)) | |
| 284 | |||
| 285 | ✗ | output = 0.0_real32 | |
| 286 | |||
| 287 | ✗ | do e = 1, num_e | |
| 288 | ✗ | dx = this%left_operand%val(:, e) | |
| 289 | ✗ | pre_act = matmul(U, dx) + b_u | |
| 290 | ✗ | hidden = max(pre_act, 0.0_real32) | |
| 291 | |||
| 292 | ! --- d/d(b_v): upstream_grad(:,e) directly --- | ||
| 293 | ✗ | output(off_bv+1:, 1) = output(off_bv+1:, 1) + upstream_grad(:, e) | |
| 294 | |||
| 295 | ! --- d/dV: upstream outer hidden => grad_V(f,h) += upstream(f,e)*hidden(h) --- | ||
| 296 | ✗ | do k = 1, H | |
| 297 | ✗ | do f_idx = 1, F | |
| 298 | ✗ | output(off_V + (k-1)*F + f_idx, 1) = & | |
| 299 | ✗ | output(off_V + (k-1)*F + f_idx, 1) + & | |
| 300 | ✗ | upstream_grad(f_idx, e) * hidden(k) | |
| 301 | end do | ||
| 302 | end do | ||
| 303 | |||
| 304 | ! --- Backprop through relu: grad_hidden = V^T @ upstream(:,e) * relu' --- | ||
| 305 | ✗ | grad_hidden = matmul(transpose(V), upstream_grad(:, e)) | |
| 306 | ✗ | do k = 1, H | |
| 307 | ✗ | if(pre_act(k) .le. 0.0_real32) grad_hidden(k) = 0.0_real32 | |
| 308 | end do | ||
| 309 | |||
| 310 | ! --- d/d(b_u): grad_hidden directly --- | ||
| 311 | ✗ | output(off_bu+1:off_V, 1) = output(off_bu+1:off_V, 1) + grad_hidden | |
| 312 | |||
| 313 | ! --- d/dU: grad_hidden outer dx => grad_U(h,dd) += grad_hidden(h)*dx(dd) --- | ||
| 314 | ✗ | do k = 1, d | |
| 315 | ✗ | do f_idx = 1, H | |
| 316 | ✗ | output(off_U + (k-1)*H + f_idx, 1) = & | |
| 317 | ✗ | output(off_U + (k-1)*H + f_idx, 1) + & | |
| 318 | ✗ | grad_hidden(f_idx) * dx(k) | |
| 319 | end do | ||
| 320 | end do | ||
| 321 | end do | ||
| 322 | |||
| 323 | ✗ | deallocate(U, b_u, V, dx, pre_act, hidden, grad_hidden) | |
| 324 | |||
| 325 | ✗ | end subroutine get_partial_gno_kernel_params_val | |
| 326 | !############################################################################### | ||
| 327 | |||
| 328 | |||
| 329 | !############################################################################### | ||
| 330 | 5 | module function gno_aggregate( & | |
| 331 |
2/4✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 5 times.
✗ Branch 3 not taken.
|
5 | features, edge_kernels, adj_ia, adj_ja, F_in, F_out & |
| 332 | ) result(c) | ||
| 333 | !! Aggregate neighbour messages using pre-computed per-edge kernels. | ||
| 334 | !! | ||
| 335 | !! For each node i: | ||
| 336 | !! m_i = sum_{j in N(i)} reshape(kappa_e, [F_out, F_in]) @ h_j | ||
| 337 | !! | ||
| 338 | !! where e is the edge index corresponding to (i, j). | ||
| 339 | !! | ||
| 340 | !! left_operand → features [F_in, num_vertices] | ||
| 341 | !! right_operand → edge_kernels [F_out*F_in, num_edges] | ||
| 342 | !! output → [F_out, num_vertices] | ||
| 343 | implicit none | ||
| 344 | |||
| 345 | ! Arguments | ||
| 346 | class(array_type), intent(in), target :: features | ||
| 347 | !! Node features [F_in, num_vertices] | ||
| 348 | class(array_type), intent(in), target :: edge_kernels | ||
| 349 | !! Per-edge kernel values [F_out*F_in, num_edges] | ||
| 350 | integer, dimension(:), intent(in) :: adj_ia | ||
| 351 | !! CSR row pointers | ||
| 352 | integer, dimension(:,:), intent(in) :: adj_ja | ||
| 353 | !! CSR column indices | ||
| 354 | integer, intent(in) :: F_in, F_out | ||
| 355 | !! Feature dimensions | ||
| 356 | type(array_type), pointer :: c | ||
| 357 | !! Aggregated node output tensor | ||
| 358 | |||
| 359 | ! Local variables | ||
| 360 | integer :: num_v, i, j, jj, edge_idx | ||
| 361 | !! Node/edge traversal indices | ||
| 362 | |||
| 363 | 5 | num_v = size(features%val, 2) | |
| 364 |
2/2✓ Branch 0 taken 10 times.
✓ Branch 1 taken 5 times.
|
15 | c => features%create_result(array_shape=[F_out, num_v]) |
| 365 |
4/4✓ Branch 0 taken 30 times.
✓ Branch 1 taken 5 times.
✓ Branch 2 taken 90 times.
✓ Branch 3 taken 30 times.
|
125 | c%val = 0.0_real32 |
| 366 | |||
| 367 |
2/2✓ Branch 0 taken 30 times.
✓ Branch 1 taken 5 times.
|
35 | do i = 1, num_v |
| 368 |
6/10✗ Branch 0 not taken.
✓ Branch 1 taken 30 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 30 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 30 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 30 times.
✓ Branch 12 taken 80 times.
✓ Branch 13 taken 30 times.
|
115 | do jj = adj_ia(i), adj_ia(i+1) - 1 |
| 369 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 80 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 80 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 80 times.
|
80 | j = adj_ja(1, jj) |
| 370 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 80 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 80 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 80 times.
|
80 | edge_idx = adj_ja(2, jj) |
| 371 | ! kappa_e reshaped to [F_out, F_in], multiplied by h_j [F_in] | ||
| 372 | ✗ | c%val(:, i) = c%val(:, i) + & | |
| 373 | matmul( & | ||
| 374 | 480 | reshape(edge_kernels%val(:, edge_idx), [F_out, F_in]), & | |
| 375 | 480 | features%val(:, j) & | |
| 376 |
29/54✗ Branch 0 not taken.
✓ Branch 1 taken 80 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 80 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 80 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 80 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 80 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 80 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 80 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 80 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 80 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 80 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 80 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 80 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 80 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 80 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 80 times.
✓ Branch 37 taken 160 times.
✓ Branch 38 taken 80 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 80 times.
✗ Branch 43 not taken.
✓ Branch 44 taken 80 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 80 times.
✗ Branch 49 not taken.
✓ Branch 50 taken 80 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 80 times.
✗ Branch 55 not taken.
✓ Branch 56 taken 80 times.
✗ Branch 59 not taken.
✓ Branch 60 taken 80 times.
✗ Branch 61 not taken.
✓ Branch 62 taken 80 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 80 times.
✗ Branch 65 not taken.
✓ Branch 66 taken 80 times.
✓ Branch 67 taken 240 times.
✓ Branch 68 taken 80 times.
|
510 | ) |
| 377 | end do | ||
| 378 | end do | ||
| 379 | |||
| 380 |
7/14✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 11 not taken.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 35 times.
✓ Branch 16 taken 5 times.
|
40 | c%indices = adj_ia |
| 381 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 5 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 5 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 5 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 5 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 5 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 5 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 5 times.
✗ Branch 20 not taken.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 5 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 80 times.
✓ Branch 27 taken 5 times.
✓ Branch 28 taken 160 times.
✓ Branch 29 taken 80 times.
|
245 | c%adj_ja = adj_ja |
| 382 | |||
| 383 | 5 | c%get_partial_left => get_partial_gno_agg_features | |
| 384 | 5 | c%get_partial_right => get_partial_gno_agg_kernels | |
| 385 | 5 | c%get_partial_left_val => get_partial_gno_agg_features_val | |
| 386 | 5 | c%get_partial_right_val => get_partial_gno_agg_kernels_val | |
| 387 |
1/2✓ Branch 0 taken 5 times.
✗ Branch 1 not taken.
|
5 | if(features%requires_grad .or. edge_kernels%requires_grad)then |
| 388 | 5 | c%requires_grad = .true. | |
| 389 | 5 | c%is_forward = features%is_forward .or. edge_kernels%is_forward | |
| 390 | 5 | c%operation = 'gno_aggregate' | |
| 391 | 5 | c%left_operand => features | |
| 392 | 5 | c%right_operand => edge_kernels | |
| 393 | 5 | c%owns_left_operand = features%is_temporary | |
| 394 | 5 | c%owns_right_operand = edge_kernels%is_temporary | |
| 395 | end if | ||
| 396 | |||
| 397 | 5 | end function gno_aggregate | |
| 398 | !------------------------------------------------------------------------------- | ||
| 399 | ✗ | function get_partial_gno_agg_features(this, upstream_grad) result(output) | |
| 400 | !! Gradient of gno_aggregate w.r.t. features (left operand) | ||
| 401 | !! | ||
| 402 | !! d(m_i)/d(h_j) = kappa_{ij}^T (the [F_in, F_out] transpose) | ||
| 403 | !! So: grad_h(j) += kappa_{ij}^T @ upstream(:,i) | ||
| 404 | implicit none | ||
| 405 | |||
| 406 | ! Arguments | ||
| 407 | class(array_type), intent(inout) :: this | ||
| 408 | !! Forward result node containing saved operands | ||
| 409 | type(array_type), intent(in) :: upstream_grad | ||
| 410 | !! Upstream gradient tensor | ||
| 411 | type(array_type) :: output | ||
| 412 | !! Gradient tensor for node features | ||
| 413 | |||
| 414 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 415 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 416 | |||
| 417 | ✗ | end function get_partial_gno_agg_features | |
| 418 | !------------------------------------------------------------------------------- | ||
| 419 | ✗ | pure subroutine get_partial_gno_agg_features_val( & | |
| 420 | ✗ | this, upstream_grad, output) | |
| 421 | !! In-place gradient w.r.t. features | ||
| 422 | implicit none | ||
| 423 | |||
| 424 | ! Arguments | ||
| 425 | class(array_type), intent(in) :: this | ||
| 426 | !! Forward result node containing saved operands | ||
| 427 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 428 | !! Upstream gradient values | ||
| 429 | real(real32), dimension(:,:), intent(out) :: output | ||
| 430 | !! Output gradient values for node features | ||
| 431 | |||
| 432 | ! Local variables | ||
| 433 | integer :: F_in, F_out, num_v, i, j, jj, edge_idx | ||
| 434 | !! Inferred dimensions and traversal indices | ||
| 435 | |||
| 436 | ! Infer dimensions from operands | ||
| 437 | ✗ | F_in = size(this%left_operand%val, 1) | |
| 438 | ✗ | F_out = size(upstream_grad, 1) | |
| 439 | ✗ | num_v = size(this%left_operand%val, 2) | |
| 440 | |||
| 441 | ✗ | output = 0.0_real32 | |
| 442 | ✗ | do i = 1, num_v | |
| 443 | ✗ | do jj = this%indices(i), this%indices(i+1) - 1 | |
| 444 | ✗ | j = this%adj_ja(1, jj) | |
| 445 | ✗ | edge_idx = this%adj_ja(2, jj) | |
| 446 | ! grad_h(j) += kappa_e^T @ upstream(:,i) | ||
| 447 | ! kappa_e is [F_out*F_in] → reshape to [F_out, F_in] | ||
| 448 | ! kappa_e^T is [F_in, F_out] | ||
| 449 | ✗ | output(:, j) = output(:, j) + & | |
| 450 | matmul( & | ||
| 451 | transpose(reshape( & | ||
| 452 | ✗ | this%right_operand%val(:, edge_idx), [F_out, F_in])), & | |
| 453 | ✗ | upstream_grad(:, i) & | |
| 454 | ✗ | ) | |
| 455 | end do | ||
| 456 | end do | ||
| 457 | |||
| 458 | ✗ | end subroutine get_partial_gno_agg_features_val | |
| 459 | !------------------------------------------------------------------------------- | ||
| 460 | ✗ | function get_partial_gno_agg_kernels(this, upstream_grad) result(output) | |
| 461 | !! Gradient of gno_aggregate w.r.t. edge_kernels (right operand) | ||
| 462 | !! | ||
| 463 | !! d(m_i)/d(kappa_e) = h_j (Kronecker-product structure) | ||
| 464 | !! For vectorised kappa: grad_kappa(e) = upstream(:,i) ⊗ h_j | ||
| 465 | implicit none | ||
| 466 | |||
| 467 | ! Arguments | ||
| 468 | class(array_type), intent(inout) :: this | ||
| 469 | !! Forward result node containing saved operands | ||
| 470 | type(array_type), intent(in) :: upstream_grad | ||
| 471 | !! Upstream gradient tensor | ||
| 472 | type(array_type) :: output | ||
| 473 | !! Gradient tensor for edge kernels | ||
| 474 | |||
| 475 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 476 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 477 | |||
| 478 | ✗ | end function get_partial_gno_agg_kernels | |
| 479 | !------------------------------------------------------------------------------- | ||
| 480 | ✗ | pure subroutine get_partial_gno_agg_kernels_val( & | |
| 481 | ✗ | this, upstream_grad, output) | |
| 482 | !! In-place gradient w.r.t. edge_kernels | ||
| 483 | !! | ||
| 484 | !! The aggregation is: m_i += reshape(kappa_e,[F_out,F_in]) @ h_j | ||
| 485 | !! So d(m_i)/d(kappa_e) viewed as reshape: | ||
| 486 | !! grad_kappa_e = vec( upstream(:,i) @ h_j^T ) | ||
| 487 | implicit none | ||
| 488 | |||
| 489 | ! Arguments | ||
| 490 | class(array_type), intent(in) :: this | ||
| 491 | !! Forward result node containing saved operands | ||
| 492 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 493 | !! Upstream gradient values | ||
| 494 | real(real32), dimension(:,:), intent(out) :: output | ||
| 495 | !! Output gradient values for edge kernels | ||
| 496 | |||
| 497 | ! Local variables | ||
| 498 | integer :: F_in, F_out, num_v, i, j, jj, edge_idx | ||
| 499 | !! Inferred dimensions and traversal indices | ||
| 500 | integer :: fo, fi | ||
| 501 | !! Feature indices for flattened kernel layout | ||
| 502 | |||
| 503 | ! Infer dimensions from operands | ||
| 504 | ✗ | F_in = size(this%left_operand%val, 1) | |
| 505 | ✗ | F_out = size(upstream_grad, 1) | |
| 506 | ✗ | num_v = size(this%left_operand%val, 2) | |
| 507 | |||
| 508 | ✗ | output = 0.0_real32 | |
| 509 | ✗ | do i = 1, num_v | |
| 510 | ✗ | do jj = this%indices(i), this%indices(i+1) - 1 | |
| 511 | ✗ | j = this%adj_ja(1, jj) | |
| 512 | ✗ | edge_idx = this%adj_ja(2, jj) | |
| 513 | ! kappa_e is stored as vec(K) where K = reshape(kappa_e, [F_out, F_in]) | ||
| 514 | ! d(m_i)/d(K(fo,fi)) = upstream(fo, i) * h(fi, j) | ||
| 515 | ! vec index: (fi-1)*F_out + fo | ||
| 516 | ✗ | do fi = 1, F_in | |
| 517 | ✗ | do fo = 1, F_out | |
| 518 | ✗ | output((fi-1)*F_out + fo, edge_idx) = & | |
| 519 | ✗ | output((fi-1)*F_out + fo, edge_idx) + & | |
| 520 | ✗ | upstream_grad(fo, i) * this%left_operand%val(fi, j) | |
| 521 | end do | ||
| 522 | end do | ||
| 523 | end do | ||
| 524 | end do | ||
| 525 | |||
| 526 | ✗ | end subroutine get_partial_gno_agg_kernels_val | |
| 527 | !############################################################################### | ||
| 528 | |||
| 529 | |||
| 530 | !############################################################################### | ||
| 531 | ! Laplace Neural Operator — encode and decode with differentiable poles | ||
| 532 | !############################################################################### | ||
| 533 | |||
| 534 | !############################################################################### | ||
| 535 | 4 | module function lno_encode( & | |
| 536 | input, poles, num_inputs, num_modes & | ||
| 537 | ) result(c) | ||
| 538 | !! Encode input through the Laplace basis built from learnable poles. | ||
| 539 | !! | ||
| 540 | !! Forward: y = E(mu) @ u [M, batch] | ||
| 541 | !! E[m,j] = exp(-mu_m * t_j), t_j = (j-1)/(n_in-1) | ||
| 542 | !! | ||
| 543 | !! left_operand → input u [n_in, batch] | ||
| 544 | !! right_operand → poles mu [M, 1] | ||
| 545 | !! output → encoded [M, batch] | ||
| 546 | implicit none | ||
| 547 | |||
| 548 | ! Arguments | ||
| 549 | class(array_type), intent(in), target :: input | ||
| 550 | !! Input signal tensor [n_in, batch] | ||
| 551 | class(array_type), intent(in), target :: poles | ||
| 552 | !! Learnable poles [M, 1] | ||
| 553 | integer, intent(in) :: num_inputs, num_modes | ||
| 554 | !! Input dimension and number of modes | ||
| 555 | type(array_type), pointer :: c | ||
| 556 | !! Encoded output tensor | ||
| 557 | |||
| 558 | ! Local variables | ||
| 559 | integer :: num_samples, m, j | ||
| 560 | !! Batch and loop indices | ||
| 561 | real(real32) :: t, s | ||
| 562 | !! Normalised coordinate and current pole value | ||
| 563 | 4 | real(real32), allocatable :: E(:,:) ! [M, n_in] | |
| 564 | !! Encoder basis matrix | ||
| 565 | |||
| 566 | 4 | num_samples = size(input%val, 2) | |
| 567 | |||
| 568 | ! Build encoder basis E [M x n_in] | ||
| 569 |
9/18✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
|
4 | allocate(E(num_modes, num_inputs)) |
| 570 |
2/2✓ Branch 0 taken 76 times.
✓ Branch 1 taken 4 times.
|
80 | do j = 1, num_inputs |
| 571 |
1/2✓ Branch 0 taken 76 times.
✗ Branch 1 not taken.
|
76 | if(num_inputs .gt. 1)then |
| 572 | 76 | t = real(j-1, real32) / real(num_inputs-1, real32) | |
| 573 | else | ||
| 574 | ✗ | t = 0.0_real32 | |
| 575 | end if | ||
| 576 |
2/2✓ Branch 0 taken 440 times.
✓ Branch 1 taken 76 times.
|
520 | do m = 1, num_modes |
| 577 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
|
440 | s = poles%val(m, 1) |
| 578 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 440 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 440 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 440 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 440 times.
|
516 | E(m, j) = exp(-s * t) |
| 579 | end do | ||
| 580 | end do | ||
| 581 | |||
| 582 | ! Forward: y = E @ u | ||
| 583 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | c => input%create_result(array_shape=[num_modes, num_samples]) |
| 584 |
12/18✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 56 times.
✓ Branch 6 taken 10 times.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 10 times.
✓ Branch 16 taken 4 times.
✓ Branch 17 taken 56 times.
✓ Branch 18 taken 10 times.
|
136 | c%val = matmul(E, input%val) |
| 585 | |||
| 586 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | deallocate(E) |
| 587 | |||
| 588 | ! Store metadata for backward | ||
| 589 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
|
4 | allocate(c%indices(2)) |
| 590 |
4/8✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✓ Branch 7 taken 4 times.
|
16 | c%indices = [num_inputs, num_modes] |
| 591 | |||
| 592 | 4 | c%get_partial_left => get_partial_lno_encode_input | |
| 593 | 4 | c%get_partial_right => get_partial_lno_encode_poles | |
| 594 | 4 | c%get_partial_left_val => get_partial_lno_encode_input_val | |
| 595 | 4 | c%get_partial_right_val => get_partial_lno_encode_poles_val | |
| 596 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if(input%requires_grad .or. poles%requires_grad)then |
| 597 | 4 | c%requires_grad = .true. | |
| 598 | 4 | c%is_forward = input%is_forward .or. poles%is_forward | |
| 599 | 4 | c%operation = 'lno_encode' | |
| 600 | 4 | c%left_operand => input | |
| 601 | 4 | c%right_operand => poles | |
| 602 | 4 | c%owns_left_operand = input%is_temporary | |
| 603 | 4 | c%owns_right_operand = poles%is_temporary | |
| 604 | end if | ||
| 605 | |||
| 606 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | end function lno_encode |
| 607 | !------------------------------------------------------------------------------- | ||
| 608 | ✗ | function get_partial_lno_encode_input(this, upstream_grad) result(output) | |
| 609 | !! Gradient of lno_encode with respect to input. | ||
| 610 | implicit none | ||
| 611 | |||
| 612 | ! Arguments | ||
| 613 | class(array_type), intent(inout) :: this | ||
| 614 | !! Forward result node containing saved operands | ||
| 615 | type(array_type), intent(in) :: upstream_grad | ||
| 616 | !! Upstream gradient tensor | ||
| 617 | type(array_type) :: output | ||
| 618 | !! Gradient tensor for input | ||
| 619 | |||
| 620 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 621 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 622 | |||
| 623 | ✗ | end function get_partial_lno_encode_input | |
| 624 | !------------------------------------------------------------------------------- | ||
| 625 | ✗ | pure subroutine get_partial_lno_encode_input_val( & | |
| 626 | ✗ | this, upstream_grad, output) | |
| 627 | !! dL/du = E^T @ upstream [n_in, batch] | ||
| 628 | implicit none | ||
| 629 | class(array_type), intent(in) :: this | ||
| 630 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 631 | real(real32), dimension(:,:), intent(out) :: output | ||
| 632 | |||
| 633 | integer :: n_in, num_modes, mode_index, j, s, num_samples | ||
| 634 | real(real32) :: t, mu_m | ||
| 635 | ✗ | real(real32), allocatable :: ET(:,:) ! [n_in, num_modes] | |
| 636 | |||
| 637 | ✗ | n_in = this%indices(1) | |
| 638 | ✗ | num_modes = this%indices(2) | |
| 639 | ✗ | num_samples = size(upstream_grad, 2) | |
| 640 | |||
| 641 | ✗ | allocate(ET(n_in, num_modes)) | |
| 642 | ✗ | do mode_index = 1, num_modes | |
| 643 | ✗ | mu_m = this%right_operand%val(mode_index, 1) | |
| 644 | ✗ | do j = 1, n_in | |
| 645 | ✗ | if(n_in .gt. 1)then | |
| 646 | ✗ | t = real(j-1, real32) / real(n_in-1, real32) | |
| 647 | else | ||
| 648 | ✗ | t = 0.0_real32 | |
| 649 | end if | ||
| 650 | ✗ | ET(j, mode_index) = exp(-mu_m * t) | |
| 651 | end do | ||
| 652 | end do | ||
| 653 | |||
| 654 | ✗ | output = matmul(ET, upstream_grad) | |
| 655 | |||
| 656 | ✗ | deallocate(ET) | |
| 657 | |||
| 658 | ✗ | end subroutine get_partial_lno_encode_input_val | |
| 659 | !------------------------------------------------------------------------------- | ||
| 660 | ✗ | function get_partial_lno_encode_poles(this, upstream_grad) result(output) | |
| 661 | !! Gradient of lno_encode with respect to poles. | ||
| 662 | implicit none | ||
| 663 | |||
| 664 | ! Arguments | ||
| 665 | class(array_type), intent(inout) :: this | ||
| 666 | !! Forward result node containing saved operands | ||
| 667 | type(array_type), intent(in) :: upstream_grad | ||
| 668 | !! Upstream gradient tensor | ||
| 669 | type(array_type) :: output | ||
| 670 | !! Gradient tensor for poles | ||
| 671 | |||
| 672 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 673 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 674 | |||
| 675 | ✗ | end function get_partial_lno_encode_poles | |
| 676 | !------------------------------------------------------------------------------- | ||
| 677 | ✗ | pure subroutine get_partial_lno_encode_poles_val( & | |
| 678 | ✗ | this, upstream_grad, output) | |
| 679 | !! dL/dmu_m per sample: | ||
| 680 | !! output[m,s] = upstream[m,s] * sum_j (-t_j) * exp(-mu_m*t_j) * u[j,s] | ||
| 681 | implicit none | ||
| 682 | class(array_type), intent(in) :: this | ||
| 683 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 684 | real(real32), dimension(:,:), intent(out) :: output | ||
| 685 | |||
| 686 | integer :: n_in, num_modes, mode_index, j, s, num_samples | ||
| 687 | real(real32) :: t, mu_m, accum | ||
| 688 | |||
| 689 | ✗ | n_in = this%indices(1) | |
| 690 | ✗ | num_modes = this%indices(2) | |
| 691 | ✗ | num_samples = size(upstream_grad, 2) | |
| 692 | |||
| 693 | ✗ | output = 0.0_real32 | |
| 694 | ✗ | do s = 1, num_samples | |
| 695 | ✗ | do mode_index = 1, num_modes | |
| 696 | ✗ | mu_m = this%right_operand%val(mode_index, 1) | |
| 697 | ✗ | accum = 0.0_real32 | |
| 698 | ✗ | do j = 1, n_in | |
| 699 | ✗ | if(n_in .gt. 1)then | |
| 700 | ✗ | t = real(j-1, real32) / real(n_in-1, real32) | |
| 701 | else | ||
| 702 | ✗ | t = 0.0_real32 | |
| 703 | end if | ||
| 704 | accum = accum + (-t) * exp(-mu_m * t) * & | ||
| 705 | ✗ | this%left_operand%val(j, s) | |
| 706 | end do | ||
| 707 | ✗ | output(mode_index, s) = upstream_grad(mode_index, s) * accum | |
| 708 | end do | ||
| 709 | end do | ||
| 710 | |||
| 711 | ✗ | end subroutine get_partial_lno_encode_poles_val | |
| 712 | !############################################################################### | ||
| 713 | |||
| 714 | |||
| 715 | !############################################################################### | ||
| 716 | 4 | module function lno_decode( & | |
| 717 | spectral, poles, num_outputs, num_modes & | ||
| 718 | ) result(c) | ||
| 719 | !! Decode through the Laplace basis built from learnable poles. | ||
| 720 | !! | ||
| 721 | !! Forward: y = D(mu) @ x [n_out, batch] | ||
| 722 | !! D[i,m] = exp(-mu_m * tau_i), tau_i = (i-1)/(n_out-1) | ||
| 723 | !! | ||
| 724 | !! left_operand → spectral x [M, batch] | ||
| 725 | !! right_operand → poles mu [M, 1] | ||
| 726 | !! output → decoded [n_out, batch] | ||
| 727 | implicit none | ||
| 728 | |||
| 729 | ! Arguments | ||
| 730 | class(array_type), intent(in), target :: spectral | ||
| 731 | !! Spectral tensor [M, batch] | ||
| 732 | class(array_type), intent(in), target :: poles | ||
| 733 | !! Learnable poles [M, 1] | ||
| 734 | integer, intent(in) :: num_outputs, num_modes | ||
| 735 | !! Output dimension and number of modes | ||
| 736 | type(array_type), pointer :: c | ||
| 737 | !! Decoded output tensor | ||
| 738 | |||
| 739 | ! Local variables | ||
| 740 | integer :: num_samples, m, i | ||
| 741 | !! Batch and loop indices | ||
| 742 | real(real32) :: t, s | ||
| 743 | !! Normalised coordinate and current pole value | ||
| 744 | 4 | real(real32), allocatable :: D(:,:) ! [n_out, M] | |
| 745 | !! Decoder basis matrix | ||
| 746 | |||
| 747 | 4 | num_samples = size(spectral%val, 2) | |
| 748 | |||
| 749 | ! Build decoder basis D [n_out x M] | ||
| 750 |
9/18✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 4 times.
✓ Branch 4 taken 4 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 4 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 4 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 4 times.
|
4 | allocate(D(num_outputs, num_modes)) |
| 751 |
2/2✓ Branch 0 taken 20 times.
✓ Branch 1 taken 4 times.
|
24 | do m = 1, num_modes |
| 752 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 20 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 20 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 20 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 20 times.
|
20 | s = poles%val(m, 1) |
| 753 |
2/2✓ Branch 0 taken 330 times.
✓ Branch 1 taken 20 times.
|
354 | do i = 1, num_outputs |
| 754 |
1/2✓ Branch 0 taken 330 times.
✗ Branch 1 not taken.
|
330 | if(num_outputs .gt. 1)then |
| 755 | 330 | t = real(i-1, real32) / real(num_outputs-1, real32) | |
| 756 | else | ||
| 757 | ✗ | t = 0.0_real32 | |
| 758 | end if | ||
| 759 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 330 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 330 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 330 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 330 times.
|
350 | D(i, m) = exp(-s * t) |
| 760 | end do | ||
| 761 | end do | ||
| 762 | |||
| 763 | ! Forward: y = D @ x | ||
| 764 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | c => spectral%create_result(array_shape=[num_outputs, num_samples]) |
| 765 |
12/18✓ Branch 1 taken 4 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 10 times.
✓ Branch 4 taken 4 times.
✓ Branch 5 taken 165 times.
✓ Branch 6 taken 10 times.
✓ Branch 7 taken 4 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 4 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 4 times.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 10 times.
✓ Branch 16 taken 4 times.
✓ Branch 17 taken 165 times.
✓ Branch 18 taken 10 times.
|
354 | c%val = matmul(D, spectral%val) |
| 766 | |||
| 767 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | deallocate(D) |
| 768 | |||
| 769 | ! Store metadata for backward | ||
| 770 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 4 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 4 times.
|
4 | allocate(c%indices(2)) |
| 771 |
4/8✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 4 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 8 times.
✓ Branch 7 taken 4 times.
|
16 | c%indices = [num_outputs, num_modes] |
| 772 | |||
| 773 | 4 | c%get_partial_left => get_partial_lno_decode_spectral | |
| 774 | 4 | c%get_partial_right => get_partial_lno_decode_poles | |
| 775 | 4 | c%get_partial_left_val => get_partial_lno_decode_spectral_val | |
| 776 | 4 | c%get_partial_right_val => get_partial_lno_decode_poles_val | |
| 777 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if(spectral%requires_grad .or. poles%requires_grad)then |
| 778 | 4 | c%requires_grad = .true. | |
| 779 | 4 | c%is_forward = spectral%is_forward .or. poles%is_forward | |
| 780 | 4 | c%operation = 'lno_decode' | |
| 781 | 4 | c%left_operand => spectral | |
| 782 | 4 | c%right_operand => poles | |
| 783 | 4 | c%owns_left_operand = spectral%is_temporary | |
| 784 | 4 | c%owns_right_operand = poles%is_temporary | |
| 785 | end if | ||
| 786 | |||
| 787 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 4 times.
|
4 | end function lno_decode |
| 788 | !------------------------------------------------------------------------------- | ||
| 789 | ✗ | function get_partial_lno_decode_spectral(this, upstream_grad) result(output) | |
| 790 | !! Gradient of lno_decode with respect to spectral input. | ||
| 791 | implicit none | ||
| 792 | |||
| 793 | ! Arguments | ||
| 794 | class(array_type), intent(inout) :: this | ||
| 795 | !! Forward result node containing saved operands | ||
| 796 | type(array_type), intent(in) :: upstream_grad | ||
| 797 | !! Upstream gradient tensor | ||
| 798 | type(array_type) :: output | ||
| 799 | !! Gradient tensor for spectral input | ||
| 800 | |||
| 801 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 802 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 803 | |||
| 804 | ✗ | end function get_partial_lno_decode_spectral | |
| 805 | !------------------------------------------------------------------------------- | ||
| 806 | ✗ | pure subroutine get_partial_lno_decode_spectral_val( & | |
| 807 | ✗ | this, upstream_grad, output) | |
| 808 | !! dL/dx = D^T @ upstream [M, batch] | ||
| 809 | implicit none | ||
| 810 | class(array_type), intent(in) :: this | ||
| 811 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 812 | real(real32), dimension(:,:), intent(out) :: output | ||
| 813 | |||
| 814 | integer :: n_out, num_modes, mode_index, i, num_samples | ||
| 815 | real(real32) :: t, mu_m | ||
| 816 | ✗ | real(real32), allocatable :: DT(:,:) ! [num_modes, n_out] | |
| 817 | |||
| 818 | ✗ | n_out = this%indices(1) | |
| 819 | ✗ | num_modes = this%indices(2) | |
| 820 | ✗ | num_samples = size(upstream_grad, 2) | |
| 821 | |||
| 822 | ✗ | allocate(DT(num_modes, n_out)) | |
| 823 | ✗ | do mode_index = 1, num_modes | |
| 824 | ✗ | mu_m = this%right_operand%val(mode_index, 1) | |
| 825 | ✗ | do i = 1, n_out | |
| 826 | ✗ | if(n_out .gt. 1)then | |
| 827 | ✗ | t = real(i-1, real32) / real(n_out-1, real32) | |
| 828 | else | ||
| 829 | ✗ | t = 0.0_real32 | |
| 830 | end if | ||
| 831 | ✗ | DT(mode_index, i) = exp(-mu_m * t) | |
| 832 | end do | ||
| 833 | end do | ||
| 834 | |||
| 835 | ✗ | output = matmul(DT, upstream_grad) | |
| 836 | |||
| 837 | ✗ | deallocate(DT) | |
| 838 | |||
| 839 | ✗ | end subroutine get_partial_lno_decode_spectral_val | |
| 840 | !------------------------------------------------------------------------------- | ||
| 841 | ✗ | function get_partial_lno_decode_poles(this, upstream_grad) result(output) | |
| 842 | !! Gradient of lno_decode with respect to poles. | ||
| 843 | implicit none | ||
| 844 | |||
| 845 | ! Arguments | ||
| 846 | class(array_type), intent(inout) :: this | ||
| 847 | !! Forward result node containing saved operands | ||
| 848 | type(array_type), intent(in) :: upstream_grad | ||
| 849 | !! Upstream gradient tensor | ||
| 850 | type(array_type) :: output | ||
| 851 | !! Gradient tensor for poles | ||
| 852 | |||
| 853 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 854 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 855 | |||
| 856 | ✗ | end function get_partial_lno_decode_poles | |
| 857 | !------------------------------------------------------------------------------- | ||
| 858 | ✗ | pure subroutine get_partial_lno_decode_poles_val( & | |
| 859 | ✗ | this, upstream_grad, output) | |
| 860 | !! dL/dmu_m per sample: | ||
| 861 | !! output[m,s] = sum_i upstream[i,s]*(-tau_i)*exp(-mu_m*tau_i)*x[m,s] | ||
| 862 | implicit none | ||
| 863 | class(array_type), intent(in) :: this | ||
| 864 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 865 | real(real32), dimension(:,:), intent(out) :: output | ||
| 866 | |||
| 867 | integer :: n_out, num_modes, mode_index, i, s, num_samples | ||
| 868 | real(real32) :: t, mu_m, accum | ||
| 869 | |||
| 870 | ✗ | n_out = this%indices(1) | |
| 871 | ✗ | num_modes = this%indices(2) | |
| 872 | ✗ | num_samples = size(upstream_grad, 2) | |
| 873 | |||
| 874 | ✗ | output = 0.0_real32 | |
| 875 | ✗ | do s = 1, num_samples | |
| 876 | ✗ | do mode_index = 1, num_modes | |
| 877 | ✗ | mu_m = this%right_operand%val(mode_index, 1) | |
| 878 | ✗ | accum = 0.0_real32 | |
| 879 | ✗ | do i = 1, n_out | |
| 880 | ✗ | if(n_out .gt. 1)then | |
| 881 | ✗ | t = real(i-1, real32) / real(n_out-1, real32) | |
| 882 | else | ||
| 883 | ✗ | t = 0.0_real32 | |
| 884 | end if | ||
| 885 | ✗ | accum = accum + upstream_grad(i, s) * (-t) * exp(-mu_m * t) | |
| 886 | end do | ||
| 887 | ✗ | output(mode_index, s) = accum * this%left_operand%val(mode_index, s) | |
| 888 | end do | ||
| 889 | end do | ||
| 890 | |||
| 891 | ✗ | end subroutine get_partial_lno_decode_poles_val | |
| 892 | !############################################################################### | ||
| 893 | |||
| 894 | |||
| 895 | !############################################################################### | ||
| 896 | ! Element-wise scale: out[i,s] = input[i,s] * scale[i,1] | ||
| 897 | ! Handles non-sample-dependent scale vectors correctly (unlike built-in *) | ||
| 898 | !############################################################################### | ||
| 899 | |||
| 900 | !############################################################################### | ||
| 901 | 4 | module function elem_scale(input, scale) result(c) | |
| 902 | !! Element-wise scaling with explicit support for sample-independent scale. | ||
| 903 | implicit none | ||
| 904 | |||
| 905 | ! Arguments | ||
| 906 | class(array_type), intent(in), target :: input | ||
| 907 | !! Input tensor [n, batch] | ||
| 908 | class(array_type), intent(in), target :: scale | ||
| 909 | !! Scale tensor [n, 1] | ||
| 910 | type(array_type), pointer :: c | ||
| 911 | !! Scaled output tensor | ||
| 912 | |||
| 913 | ! Local variables | ||
| 914 | integer :: i, s, n, ns | ||
| 915 | !! Feature/sample indices and dimensions | ||
| 916 | |||
| 917 | 4 | n = size(input%val, 1) | |
| 918 | 4 | ns = size(input%val, 2) | |
| 919 | |||
| 920 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
12 | c => input%create_result(array_shape=[n, ns]) |
| 921 | 4 | do concurrent(s = 1:ns, i = 1:n) | |
| 922 |
16/28✓ Branch 0 taken 20 times.
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 56 times.
✓ Branch 3 taken 20 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 56 times.
✗ Branch 7 not taken.
✓ Branch 8 taken 56 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 56 times.
✗ Branch 13 not taken.
✓ Branch 14 taken 56 times.
✗ Branch 16 not taken.
✓ Branch 17 taken 56 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 56 times.
✗ Branch 22 not taken.
✓ Branch 23 taken 56 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 56 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 56 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 56 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 56 times.
✗ Branch 37 not taken.
✓ Branch 38 taken 56 times.
|
80 | c%val(i, s) = input%val(i, s) * scale%val(i, 1) |
| 923 | end do | ||
| 924 | |||
| 925 | 4 | c%get_partial_left => null() | |
| 926 | 4 | c%get_partial_right => null() | |
| 927 | 4 | c%get_partial_left_val => get_partial_elem_scale_input_val | |
| 928 | 4 | c%get_partial_right_val => get_partial_elem_scale_scale_val | |
| 929 |
1/2✓ Branch 0 taken 4 times.
✗ Branch 1 not taken.
|
4 | if(input%requires_grad .or. scale%requires_grad)then |
| 930 | 4 | c%requires_grad = .true. | |
| 931 | 4 | c%is_forward = input%is_forward .or. scale%is_forward | |
| 932 | 4 | c%operation = 'elem_scale' | |
| 933 | 4 | c%left_operand => input | |
| 934 | 4 | c%right_operand => scale | |
| 935 | 4 | c%owns_left_operand = input%is_temporary | |
| 936 | 4 | c%owns_right_operand = scale%is_temporary | |
| 937 | end if | ||
| 938 | |||
| 939 | 4 | end function elem_scale | |
| 940 | !------------------------------------------------------------------------------- | ||
| 941 | |||
| 942 | |||
| 943 | !------------------------------------------------------------------------------- | ||
| 944 | ✗ | pure subroutine get_partial_elem_scale_input_val(this, upstream_grad, output) | |
| 945 | !! d(out)/d(input): upstream * scale (broadcast scale along samples) | ||
| 946 | implicit none | ||
| 947 | |||
| 948 | ! Arguments | ||
| 949 | class(array_type), intent(in) :: this | ||
| 950 | !! Forward result node containing saved operands | ||
| 951 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 952 | !! Upstream gradient values | ||
| 953 | real(real32), dimension(:,:), intent(out) :: output | ||
| 954 | !! Output gradient values for input | ||
| 955 | |||
| 956 | ! Local variables | ||
| 957 | integer :: i, s | ||
| 958 | !! Feature and sample indices | ||
| 959 | |||
| 960 | ✗ | do concurrent(s = 1:size(output,2), i = 1:size(output,1)) | |
| 961 | ✗ | output(i, s) = upstream_grad(i, s) * this%right_operand%val(i, 1) | |
| 962 | end do | ||
| 963 | |||
| 964 | ✗ | end subroutine get_partial_elem_scale_input_val | |
| 965 | !------------------------------------------------------------------------------- | ||
| 966 | |||
| 967 | |||
| 968 | !------------------------------------------------------------------------------- | ||
| 969 | ✗ | pure subroutine get_partial_elem_scale_scale_val(this, upstream_grad, output) | |
| 970 | !! d(out)/d(scale): upstream * input (element-wise, per sample) | ||
| 971 | implicit none | ||
| 972 | |||
| 973 | ! Arguments | ||
| 974 | class(array_type), intent(in) :: this | ||
| 975 | !! Forward result node containing saved operands | ||
| 976 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 977 | !! Upstream gradient values | ||
| 978 | real(real32), dimension(:,:), intent(out) :: output | ||
| 979 | !! Output gradient values for scale tensor | ||
| 980 | |||
| 981 | ! Local variables | ||
| 982 | integer :: i, s | ||
| 983 | !! Feature and sample indices | ||
| 984 | |||
| 985 | ✗ | do concurrent(s = 1:size(output,2), i = 1:size(output,1)) | |
| 986 | ✗ | output(i, s) = upstream_grad(i, s) * this%left_operand%val(i, s) | |
| 987 | end do | ||
| 988 | |||
| 989 | ✗ | end subroutine get_partial_elem_scale_scale_val | |
| 990 | !############################################################################### | ||
| 991 | |||
| 992 | |||
| 993 | !############################################################################### | ||
| 994 | ! Orthogonal Neural Operator — encode and decode with differentiable basis | ||
| 995 | !############################################################################### | ||
| 996 | |||
| 997 | !############################################################################### | ||
| 998 | 7 | module function ono_encode( & | |
| 999 | input, basis_weights, num_inputs, num_basis & | ||
| 1000 | ) result(c) | ||
| 1001 | !! Encode input through an orthogonalised basis. | ||
| 1002 | !! | ||
| 1003 | !! Forward: y = Q(B)^T @ u [k, batch] | ||
| 1004 | !! Q = modified_gram_schmidt(B), B [n x k] from basis_weights | ||
| 1005 | !! | ||
| 1006 | !! left_operand → input u [n, batch] | ||
| 1007 | !! right_operand → basis weights B [n*k, 1] | ||
| 1008 | !! output → encoded [k, batch] | ||
| 1009 | implicit none | ||
| 1010 | |||
| 1011 | ! Arguments | ||
| 1012 | class(array_type), intent(in), target :: input | ||
| 1013 | !! Input tensor [n, batch] | ||
| 1014 | class(array_type), intent(in), target :: basis_weights | ||
| 1015 | !! Flattened basis matrix parameters [n*k, 1] | ||
| 1016 | integer, intent(in) :: num_inputs, num_basis | ||
| 1017 | !! Input dimension and basis size | ||
| 1018 | type(array_type), pointer :: c | ||
| 1019 | !! Encoded output tensor | ||
| 1020 | |||
| 1021 | ! Local variables | ||
| 1022 | integer :: num_samples, n, k, i, j, s | ||
| 1023 | !! Batch/dimension values and loop indices | ||
| 1024 | 7 | real(real32), allocatable :: B(:,:), Q(:,:), QT(:,:) | |
| 1025 | !! Basis matrix, orthonormal basis and transpose buffer | ||
| 1026 | real(real32) :: norm_val, proj | ||
| 1027 | !! Gram-Schmidt norm and projection scalars | ||
| 1028 | |||
| 1029 | 7 | n = num_inputs | |
| 1030 | 7 | k = num_basis | |
| 1031 | 7 | num_samples = size(input%val, 2) | |
| 1032 | |||
| 1033 | ! Modified Gram-Schmidt: B -> Q | ||
| 1034 |
27/54✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
✓ Branch 4 taken 7 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 7 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 25 taken 7 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 7 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 7 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 7 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 7 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 7 times.
✓ Branch 42 taken 7 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 7 times.
✓ Branch 46 taken 7 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 7 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 7 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 7 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 7 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 7 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 7 times.
|
7 | allocate(B(n, k), Q(n, k), QT(k, n)) |
| 1035 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 7 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
|
21 | B = reshape(basis_weights%val(:, 1), [n, k]) |
| 1036 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 24 taken 7 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 24 times.
✓ Branch 41 taken 7 times.
✓ Branch 42 taken 320 times.
✓ Branch 43 taken 24 times.
|
351 | Q = B |
| 1037 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 7 times.
|
31 | do j = 1, k |
| 1038 |
2/2✓ Branch 0 taken 33 times.
✓ Branch 1 taken 24 times.
|
57 | do i = 1, j - 1 |
| 1039 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 33 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 33 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 33 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 33 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 33 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 33 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 33 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 33 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 33 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 33 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 33 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 33 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 33 times.
✓ Branch 39 taken 556 times.
✓ Branch 40 taken 33 times.
|
589 | proj = dot_product(Q(:,i), Q(:,j)) |
| 1040 |
22/42✗ Branch 0 not taken.
✓ Branch 1 taken 33 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 33 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 33 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 33 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 33 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 33 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 33 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 33 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 33 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 33 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 33 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 33 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 33 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 33 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 33 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 33 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 33 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 33 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 33 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 33 times.
✓ Branch 60 taken 556 times.
✓ Branch 61 taken 33 times.
|
613 | Q(:,j) = Q(:,j) - proj * Q(:,i) |
| 1041 | end do | ||
| 1042 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 24 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 24 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 24 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 24 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 24 times.
✓ Branch 39 taken 320 times.
✓ Branch 40 taken 24 times.
|
344 | norm_val = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1043 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
31 | if(norm_val .gt. 1.0e-12_real32)then |
| 1044 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 24 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 24 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 24 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 24 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 24 times.
✓ Branch 39 taken 320 times.
✓ Branch 40 taken 24 times.
|
344 | Q(:,j) = Q(:,j) / norm_val |
| 1045 | else | ||
| 1046 | ✗ | Q(:,j) = 0.0_real32 | |
| 1047 | end if | ||
| 1048 | end do | ||
| 1049 | |||
| 1050 | ! Transpose | ||
| 1051 |
2/2✓ Branch 0 taken 76 times.
✓ Branch 1 taken 7 times.
|
83 | do j = 1, n |
| 1052 |
2/2✓ Branch 0 taken 320 times.
✓ Branch 1 taken 76 times.
|
403 | do i = 1, k |
| 1053 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 320 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 320 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 320 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 320 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 320 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 320 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 320 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 320 times.
|
396 | QT(i, j) = Q(j, i) |
| 1054 | end do | ||
| 1055 | end do | ||
| 1056 | |||
| 1057 | ! Forward: y = Q^T @ u | ||
| 1058 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 7 times.
|
21 | c => input%create_result(array_shape=[k, num_samples]) |
| 1059 |
12/18✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13 times.
✓ Branch 4 taken 7 times.
✓ Branch 5 taken 50 times.
✓ Branch 6 taken 13 times.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 13 times.
✓ Branch 16 taken 7 times.
✓ Branch 17 taken 50 times.
✓ Branch 18 taken 13 times.
|
133 | c%val = matmul(QT, input%val) |
| 1060 | |||
| 1061 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
|
7 | deallocate(B, Q, QT) |
| 1062 | |||
| 1063 | ! Store metadata | ||
| 1064 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
|
7 | allocate(c%indices(2)) |
| 1065 |
4/8✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 14 times.
✓ Branch 7 taken 7 times.
|
28 | c%indices = [n, k] |
| 1066 | |||
| 1067 | 7 | c%get_partial_left => get_partial_ono_encode_input | |
| 1068 | 7 | c%get_partial_right => get_partial_ono_encode_basis | |
| 1069 | 7 | c%get_partial_left_val => get_partial_ono_encode_input_val | |
| 1070 | 7 | c%get_partial_right_val => get_partial_ono_encode_basis_val | |
| 1071 |
1/2✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
|
7 | if(input%requires_grad .or. basis_weights%requires_grad)then |
| 1072 | 7 | c%requires_grad = .true. | |
| 1073 | 7 | c%is_forward = input%is_forward .or. basis_weights%is_forward | |
| 1074 | 7 | c%operation = 'ono_encode' | |
| 1075 | 7 | c%left_operand => input | |
| 1076 | 7 | c%right_operand => basis_weights | |
| 1077 | 7 | c%owns_left_operand = input%is_temporary | |
| 1078 | 7 | c%owns_right_operand = basis_weights%is_temporary | |
| 1079 | end if | ||
| 1080 | |||
| 1081 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 7 times.
|
7 | end function ono_encode |
| 1082 | !------------------------------------------------------------------------------- | ||
| 1083 | ✗ | function get_partial_ono_encode_input(this, upstream_grad) result(output) | |
| 1084 | !! Gradient of ono_encode with respect to input. | ||
| 1085 | implicit none | ||
| 1086 | |||
| 1087 | ! Arguments | ||
| 1088 | class(array_type), intent(inout) :: this | ||
| 1089 | !! Forward result node containing saved operands | ||
| 1090 | type(array_type), intent(in) :: upstream_grad | ||
| 1091 | !! Upstream gradient tensor | ||
| 1092 | type(array_type) :: output | ||
| 1093 | !! Gradient tensor for input | ||
| 1094 | |||
| 1095 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 1096 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 1097 | |||
| 1098 | ✗ | end function get_partial_ono_encode_input | |
| 1099 | !------------------------------------------------------------------------------- | ||
| 1100 | 1 | pure subroutine get_partial_ono_encode_input_val( & | |
| 1101 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | this, upstream_grad, output) |
| 1102 | !! dL/du = Q @ upstream [n, batch] | ||
| 1103 | implicit none | ||
| 1104 | class(array_type), intent(in) :: this | ||
| 1105 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 1106 | real(real32), dimension(:,:), intent(out) :: output | ||
| 1107 | |||
| 1108 | integer :: n, k, i, j | ||
| 1109 | 1 | real(real32), allocatable :: B(:,:), Q(:,:) | |
| 1110 | real(real32) :: norm_val, proj | ||
| 1111 | |||
| 1112 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | n = this%indices(1) |
| 1113 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | k = this%indices(2) |
| 1114 | |||
| 1115 | ! Recompute Q from B | ||
| 1116 |
18/36✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
|
1 | allocate(B(n, k), Q(n, k)) |
| 1117 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
|
3 | B = reshape(this%right_operand%val(:,1), [n, k]) |
| 1118 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | Q = B |
| 1119 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1120 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = 1, j - 1 |
| 1121 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
5 | proj = dot_product(Q(:,i), Q(:,j)) |
| 1122 |
22/42✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✓ Branch 60 taken 4 times.
✓ Branch 61 taken 1 times.
|
7 | Q(:,j) = Q(:,j) - proj * Q(:,i) |
| 1123 | end do | ||
| 1124 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | norm_val = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1125 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
3 | if(norm_val .gt. 1.0e-12_real32)then |
| 1126 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | Q(:,j) = Q(:,j) / norm_val |
| 1127 | else | ||
| 1128 | ✗ | Q(:,j) = 0.0_real32 | |
| 1129 | end if | ||
| 1130 | end do | ||
| 1131 | |||
| 1132 | ! dL/du = Q @ upstream | ||
| 1133 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
|
1 | output = matmul(Q, upstream_grad) |
| 1134 | |||
| 1135 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | deallocate(B, Q) |
| 1136 | |||
| 1137 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
|
1 | end subroutine get_partial_ono_encode_input_val |
| 1138 | !------------------------------------------------------------------------------- | ||
| 1139 | ✗ | function get_partial_ono_encode_basis(this, upstream_grad) result(output) | |
| 1140 | !! Gradient of ono_encode with respect to basis weights. | ||
| 1141 | implicit none | ||
| 1142 | |||
| 1143 | ! Arguments | ||
| 1144 | class(array_type), intent(inout) :: this | ||
| 1145 | !! Forward result node containing saved operands | ||
| 1146 | type(array_type), intent(in) :: upstream_grad | ||
| 1147 | !! Upstream gradient tensor | ||
| 1148 | type(array_type) :: output | ||
| 1149 | !! Gradient tensor for basis weights | ||
| 1150 | |||
| 1151 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 1152 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 1153 | |||
| 1154 | ✗ | end function get_partial_ono_encode_basis | |
| 1155 | !------------------------------------------------------------------------------- | ||
| 1156 | 1 | pure subroutine get_partial_ono_encode_basis_val( & | |
| 1157 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | this, upstream_grad, output) |
| 1158 | !! dL/dB per sample through Gram-Schmidt backward. | ||
| 1159 | !! | ||
| 1160 | !! For encode y = Q^T @ u: | ||
| 1161 | !! dL/dQ from sample s: u(:,s) @ upstream(:,s)^T → [n, k] | ||
| 1162 | !! dL/dB from sample s: gs_backward(B, dL/dQ_s) → [n, k] | ||
| 1163 | implicit none | ||
| 1164 | class(array_type), intent(in) :: this | ||
| 1165 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 1166 | real(real32), dimension(:,:), intent(out) :: output | ||
| 1167 | |||
| 1168 | integer :: n, k, s, i, j, num_samples | ||
| 1169 | 1 | real(real32), allocatable :: B(:,:), Q(:,:), R(:,:) | |
| 1170 | 1 | real(real32), allocatable :: dQ(:,:), dQ_work(:,:), dB(:,:) | |
| 1171 | 1 | real(real32), allocatable :: dv(:), v_recon(:) | |
| 1172 | real(real32) :: norm_j, dprod, dR_ij, proj | ||
| 1173 | |||
| 1174 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | n = this%indices(1) |
| 1175 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | k = this%indices(2) |
| 1176 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
|
1 | num_samples = size(upstream_grad, 2) |
| 1177 | |||
| 1178 | ! Recompute Q and R from B via modified Gram-Schmidt | ||
| 1179 |
27/54✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
|
1 | allocate(B(n, k), Q(n, k), R(k, k)) |
| 1180 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
|
3 | B = reshape(this%right_operand%val(:,1), [n, k]) |
| 1181 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | Q = B |
| 1182 | 1 | R = 0.0_real32 | |
| 1183 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1184 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = 1, j - 1 |
| 1185 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 1 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
|
5 | R(i,j) = dot_product(Q(:,i), Q(:,j)) |
| 1186 |
26/50✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 1 times.
✓ Branch 72 taken 4 times.
✓ Branch 73 taken 1 times.
|
7 | Q(:,j) = Q(:,j) - R(i,j) * Q(:,i) |
| 1187 | end do | ||
| 1188 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 2 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 2 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 2 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 2 times.
|
10 | R(j,j) = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1189 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
|
3 | if(R(j,j) .gt. 1.0e-12_real32)then |
| 1190 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 2 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 2 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 2 times.
✓ Branch 51 taken 8 times.
✓ Branch 52 taken 2 times.
|
10 | Q(:,j) = Q(:,j) / R(j,j) |
| 1191 | else | ||
| 1192 | ✗ | Q(:,j) = 0.0_real32 | |
| 1193 | end if | ||
| 1194 | end do | ||
| 1195 | |||
| 1196 |
27/54✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
|
1 | allocate(dQ(n, k), dQ_work(n, k), dB(n, k)) |
| 1197 |
14/28✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
|
1 | allocate(dv(n), v_recon(n)) |
| 1198 | |||
| 1199 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 8 times.
✓ Branch 21 taken 1 times.
|
10 | output = 0.0_real32 |
| 1200 | |||
| 1201 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | do s = 1, num_samples |
| 1202 | ! dL/dQ for this sample: u(:,s) outer upstream(:,s) | ||
| 1203 | ! dQ[j_n, i_k] = u(j_n, s) * upstream(i_k, s) | ||
| 1204 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1205 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
|
11 | do i = 1, n |
| 1206 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 8 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 8 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 8 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 8 times.
|
10 | dQ(i, j) = this%left_operand%val(i, s) * upstream_grad(j, s) |
| 1207 | end do | ||
| 1208 | end do | ||
| 1209 | |||
| 1210 | ! Gram-Schmidt backward: dQ -> dB | ||
| 1211 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | dQ_work = dQ |
| 1212 | 1 | dB = 0.0_real32 | |
| 1213 | |||
| 1214 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = k, 1, -1 |
| 1215 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | norm_j = R(j, j) |
| 1216 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
|
2 | if(norm_j .le. 1.0e-12_real32)then |
| 1217 | ✗ | dB(:,j) = 0.0_real32 | |
| 1218 | ✗ | cycle | |
| 1219 | end if | ||
| 1220 | |||
| 1221 | ! Backward through normalization | ||
| 1222 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | dprod = dot_product(dQ_work(:,j), Q(:,j)) |
| 1223 |
17/34✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 2 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 2 times.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 8 times.
✓ Branch 46 taken 2 times.
|
10 | dv = (dQ_work(:,j) - dprod * Q(:,j)) / norm_j |
| 1224 | |||
| 1225 | ! Reconstruct v before normalization | ||
| 1226 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 8 times.
✓ Branch 25 taken 2 times.
|
10 | v_recon = norm_j * Q(:,j) |
| 1227 | |||
| 1228 | ! Backward through projections (reverse order) | ||
| 1229 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = j-1, 1, -1 |
| 1230 |
19/38✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✓ Branch 45 taken 1 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 1 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 4 times.
✓ Branch 52 taken 1 times.
|
5 | v_recon = v_recon + R(i,j) * Q(:,i) |
| 1231 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✓ Branch 33 taken 4 times.
✓ Branch 34 taken 1 times.
|
5 | dR_ij = -dot_product(dv, Q(:,i)) |
| 1232 |
24/46✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✓ Branch 66 taken 4 times.
✓ Branch 67 taken 1 times.
|
5 | dQ_work(:,i) = dQ_work(:,i) - R(i,j) * dv |
| 1233 |
20/38✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✓ Branch 54 taken 4 times.
✓ Branch 55 taken 1 times.
|
5 | dQ_work(:,i) = dQ_work(:,i) + dR_ij * v_recon |
| 1234 |
15/30✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✓ Branch 33 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
7 | dv = dv + dR_ij * Q(:,i) |
| 1235 | end do | ||
| 1236 | |||
| 1237 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✓ Branch 33 taken 8 times.
✓ Branch 34 taken 2 times.
|
11 | dB(:,j) = dv |
| 1238 | end do | ||
| 1239 | |||
| 1240 |
7/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✓ Branch 15 taken 1 times.
✓ Branch 16 taken 1 times.
|
3 | output(:, s) = reshape(dB, [n*k]) |
| 1241 | end do | ||
| 1242 | |||
| 1243 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
|
1 | deallocate(B, Q, R, dQ, dQ_work, dB, dv, v_recon) |
| 1244 | |||
| 1245 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | end subroutine get_partial_ono_encode_basis_val |
| 1246 | !############################################################################### | ||
| 1247 | |||
| 1248 | |||
| 1249 | !############################################################################### | ||
| 1250 | 7 | module function ono_decode( & | |
| 1251 | mixed, basis_weights, num_inputs, num_basis & | ||
| 1252 | ) result(c) | ||
| 1253 | !! Decode through an orthogonalised basis. | ||
| 1254 | !! | ||
| 1255 | !! Forward: y = Q(B) @ x [n, batch] | ||
| 1256 | !! Q = modified_gram_schmidt(B), B [n x k] from basis_weights | ||
| 1257 | !! | ||
| 1258 | !! left_operand → mixed x [k, batch] | ||
| 1259 | !! right_operand → basis weights B [n*k, 1] | ||
| 1260 | !! output → decoded [n, batch] | ||
| 1261 | implicit none | ||
| 1262 | |||
| 1263 | ! Arguments | ||
| 1264 | class(array_type), intent(in), target :: mixed | ||
| 1265 | !! Mixed spectral tensor [k, batch] | ||
| 1266 | class(array_type), intent(in), target :: basis_weights | ||
| 1267 | !! Flattened basis matrix parameters [n*k, 1] | ||
| 1268 | integer, intent(in) :: num_inputs, num_basis | ||
| 1269 | !! Output dimension and basis size | ||
| 1270 | type(array_type), pointer :: c | ||
| 1271 | !! Decoded output tensor | ||
| 1272 | |||
| 1273 | ! Local variables | ||
| 1274 | integer :: num_samples, n, k, i, j | ||
| 1275 | !! Batch/dimension values and loop indices | ||
| 1276 | 7 | real(real32), allocatable :: B(:,:), Q(:,:) | |
| 1277 | !! Basis matrix and orthonormal basis | ||
| 1278 | real(real32) :: norm_val, proj | ||
| 1279 | !! Gram-Schmidt norm and projection scalars | ||
| 1280 | |||
| 1281 | 7 | n = num_inputs | |
| 1282 | 7 | k = num_basis | |
| 1283 | 7 | num_samples = size(mixed%val, 2) | |
| 1284 | |||
| 1285 | ! Modified Gram-Schmidt: B -> Q | ||
| 1286 |
18/36✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
✓ Branch 4 taken 7 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 7 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✓ Branch 25 taken 7 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 7 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 7 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 7 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 7 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 7 times.
|
7 | allocate(B(n, k), Q(n, k)) |
| 1287 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✓ Branch 18 taken 14 times.
✓ Branch 19 taken 7 times.
✓ Branch 21 taken 7 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 7 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
|
21 | B = reshape(basis_weights%val(:, 1), [n, k]) |
| 1288 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 7 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 7 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 7 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 7 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 7 times.
✓ Branch 24 taken 7 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 7 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 7 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 24 times.
✓ Branch 41 taken 7 times.
✓ Branch 42 taken 320 times.
✓ Branch 43 taken 24 times.
|
351 | Q = B |
| 1289 |
2/2✓ Branch 0 taken 24 times.
✓ Branch 1 taken 7 times.
|
31 | do j = 1, k |
| 1290 |
2/2✓ Branch 0 taken 33 times.
✓ Branch 1 taken 24 times.
|
57 | do i = 1, j - 1 |
| 1291 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 33 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 33 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 33 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 33 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 33 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 33 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 33 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 33 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 33 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 33 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 33 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 33 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 33 times.
✓ Branch 39 taken 556 times.
✓ Branch 40 taken 33 times.
|
589 | proj = dot_product(Q(:,i), Q(:,j)) |
| 1292 |
22/42✗ Branch 0 not taken.
✓ Branch 1 taken 33 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 33 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 33 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 33 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 33 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 33 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 33 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 33 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 33 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 33 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 33 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 33 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 33 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 33 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 33 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 33 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 33 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 33 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 33 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 33 times.
✓ Branch 60 taken 556 times.
✓ Branch 61 taken 33 times.
|
613 | Q(:,j) = Q(:,j) - proj * Q(:,i) |
| 1293 | end do | ||
| 1294 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 24 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 24 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 24 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 24 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 24 times.
✓ Branch 39 taken 320 times.
✓ Branch 40 taken 24 times.
|
344 | norm_val = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1295 |
1/2✓ Branch 0 taken 24 times.
✗ Branch 1 not taken.
|
31 | if(norm_val .gt. 1.0e-12_real32)then |
| 1296 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 24 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 24 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 24 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 24 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 24 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 24 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 24 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 24 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 24 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 24 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 24 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 24 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 24 times.
✓ Branch 39 taken 320 times.
✓ Branch 40 taken 24 times.
|
344 | Q(:,j) = Q(:,j) / norm_val |
| 1297 | else | ||
| 1298 | ✗ | Q(:,j) = 0.0_real32 | |
| 1299 | end if | ||
| 1300 | end do | ||
| 1301 | |||
| 1302 | ! Forward: y = Q @ x | ||
| 1303 |
2/2✓ Branch 0 taken 14 times.
✓ Branch 1 taken 7 times.
|
21 | c => mixed%create_result(array_shape=[n, num_samples]) |
| 1304 |
12/18✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 13 times.
✓ Branch 4 taken 7 times.
✓ Branch 5 taken 184 times.
✓ Branch 6 taken 13 times.
✓ Branch 7 taken 7 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 7 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 7 times.
✗ Branch 12 not taken.
✗ Branch 13 not taken.
✗ Branch 14 not taken.
✓ Branch 15 taken 13 times.
✓ Branch 16 taken 7 times.
✓ Branch 17 taken 184 times.
✓ Branch 18 taken 13 times.
|
401 | c%val = matmul(Q, mixed%val) |
| 1305 | |||
| 1306 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
|
7 | deallocate(B, Q) |
| 1307 | |||
| 1308 | ! Store metadata | ||
| 1309 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 7 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 7 times.
|
7 | allocate(c%indices(2)) |
| 1310 |
4/8✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 7 times.
✗ Branch 3 not taken.
✗ Branch 4 not taken.
✗ Branch 5 not taken.
✓ Branch 6 taken 14 times.
✓ Branch 7 taken 7 times.
|
28 | c%indices = [n, k] |
| 1311 | |||
| 1312 | 7 | c%get_partial_left => get_partial_ono_decode_mixed | |
| 1313 | 7 | c%get_partial_right => get_partial_ono_decode_basis | |
| 1314 | 7 | c%get_partial_left_val => get_partial_ono_decode_mixed_val | |
| 1315 | 7 | c%get_partial_right_val => get_partial_ono_decode_basis_val | |
| 1316 |
1/2✓ Branch 0 taken 7 times.
✗ Branch 1 not taken.
|
7 | if(mixed%requires_grad .or. basis_weights%requires_grad)then |
| 1317 | 7 | c%requires_grad = .true. | |
| 1318 | 7 | c%is_forward = mixed%is_forward .or. basis_weights%is_forward | |
| 1319 | 7 | c%operation = 'ono_decode' | |
| 1320 | 7 | c%left_operand => mixed | |
| 1321 | 7 | c%right_operand => basis_weights | |
| 1322 | 7 | c%owns_left_operand = mixed%is_temporary | |
| 1323 | 7 | c%owns_right_operand = basis_weights%is_temporary | |
| 1324 | end if | ||
| 1325 | |||
| 1326 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 7 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 7 times.
|
7 | end function ono_decode |
| 1327 | !------------------------------------------------------------------------------- | ||
| 1328 | ✗ | function get_partial_ono_decode_mixed(this, upstream_grad) result(output) | |
| 1329 | !! Gradient of ono_decode with respect to mixed input. | ||
| 1330 | implicit none | ||
| 1331 | |||
| 1332 | ! Arguments | ||
| 1333 | class(array_type), intent(inout) :: this | ||
| 1334 | !! Forward result node containing saved operands | ||
| 1335 | type(array_type), intent(in) :: upstream_grad | ||
| 1336 | !! Upstream gradient tensor | ||
| 1337 | type(array_type) :: output | ||
| 1338 | !! Gradient tensor for mixed input | ||
| 1339 | |||
| 1340 | ✗ | call output%allocate(array_shape=shape(this%left_operand%val)) | |
| 1341 | ✗ | call this%get_partial_left_val(upstream_grad%val, output%val) | |
| 1342 | |||
| 1343 | ✗ | end function get_partial_ono_decode_mixed | |
| 1344 | !------------------------------------------------------------------------------- | ||
| 1345 | 1 | pure subroutine get_partial_ono_decode_mixed_val( & | |
| 1346 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | this, upstream_grad, output) |
| 1347 | !! dL/dx = Q^T @ upstream [k, batch] | ||
| 1348 | implicit none | ||
| 1349 | class(array_type), intent(in) :: this | ||
| 1350 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 1351 | real(real32), dimension(:,:), intent(out) :: output | ||
| 1352 | |||
| 1353 | integer :: n, k, i, j | ||
| 1354 | 1 | real(real32), allocatable :: B(:,:), Q(:,:), QT(:,:) | |
| 1355 | real(real32) :: norm_val, proj | ||
| 1356 | |||
| 1357 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | n = this%indices(1) |
| 1358 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | k = this%indices(2) |
| 1359 | |||
| 1360 | ! Recompute Q from B | ||
| 1361 |
27/54✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
|
1 | allocate(B(n, k), Q(n, k), QT(k, n)) |
| 1362 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
|
3 | B = reshape(this%right_operand%val(:,1), [n, k]) |
| 1363 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | Q = B |
| 1364 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1365 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = 1, j - 1 |
| 1366 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
5 | proj = dot_product(Q(:,i), Q(:,j)) |
| 1367 |
22/42✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✓ Branch 60 taken 4 times.
✓ Branch 61 taken 1 times.
|
7 | Q(:,j) = Q(:,j) - proj * Q(:,i) |
| 1368 | end do | ||
| 1369 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | norm_val = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1370 |
1/2✓ Branch 0 taken 2 times.
✗ Branch 1 not taken.
|
3 | if(norm_val .gt. 1.0e-12_real32)then |
| 1371 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | Q(:,j) = Q(:,j) / norm_val |
| 1372 | else | ||
| 1373 | ✗ | Q(:,j) = 0.0_real32 | |
| 1374 | end if | ||
| 1375 | end do | ||
| 1376 | |||
| 1377 | ! Transpose | ||
| 1378 |
2/2✓ Branch 0 taken 4 times.
✓ Branch 1 taken 1 times.
|
5 | do j = 1, n |
| 1379 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 4 times.
|
13 | do i = 1, k |
| 1380 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
|
12 | QT(i, j) = Q(j, i) |
| 1381 | end do | ||
| 1382 | end do | ||
| 1383 | |||
| 1384 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
|
1 | output = matmul(QT, upstream_grad) |
| 1385 | |||
| 1386 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
|
1 | deallocate(B, Q, QT) |
| 1387 | |||
| 1388 |
3/6✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
|
1 | end subroutine get_partial_ono_decode_mixed_val |
| 1389 | !------------------------------------------------------------------------------- | ||
| 1390 | ✗ | function get_partial_ono_decode_basis(this, upstream_grad) result(output) | |
| 1391 | !! Gradient of ono_decode with respect to basis weights. | ||
| 1392 | implicit none | ||
| 1393 | |||
| 1394 | ! Arguments | ||
| 1395 | class(array_type), intent(inout) :: this | ||
| 1396 | !! Forward result node containing saved operands | ||
| 1397 | type(array_type), intent(in) :: upstream_grad | ||
| 1398 | !! Upstream gradient tensor | ||
| 1399 | type(array_type) :: output | ||
| 1400 | !! Gradient tensor for basis weights | ||
| 1401 | |||
| 1402 | ✗ | call output%allocate(array_shape=shape(this%right_operand%val)) | |
| 1403 | ✗ | call this%get_partial_right_val(upstream_grad%val, output%val) | |
| 1404 | |||
| 1405 | ✗ | end function get_partial_ono_decode_basis | |
| 1406 | !------------------------------------------------------------------------------- | ||
| 1407 | 1 | pure subroutine get_partial_ono_decode_basis_val( & | |
| 1408 |
2/4✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✓ Branch 2 taken 1 times.
✗ Branch 3 not taken.
|
1 | this, upstream_grad, output) |
| 1409 | !! dL/dB per sample through Gram-Schmidt backward. | ||
| 1410 | !! | ||
| 1411 | !! For decode y = Q @ x: | ||
| 1412 | !! dL/dQ from sample s: upstream(:,s) @ x(:,s)^T → [n, k] | ||
| 1413 | !! dL/dB from sample s: gs_backward(B, dL/dQ_s) → [n, k] | ||
| 1414 | implicit none | ||
| 1415 | class(array_type), intent(in) :: this | ||
| 1416 | real(real32), dimension(:,:), intent(in) :: upstream_grad | ||
| 1417 | real(real32), dimension(:,:), intent(out) :: output | ||
| 1418 | |||
| 1419 | integer :: n, k, s, i, j, num_samples | ||
| 1420 | 1 | real(real32), allocatable :: B(:,:), Q(:,:), R(:,:) | |
| 1421 | 1 | real(real32), allocatable :: dQ(:,:), dQ_work(:,:), dB(:,:) | |
| 1422 | 1 | real(real32), allocatable :: dv(:), v_recon(:) | |
| 1423 | real(real32) :: norm_j, dprod, dR_ij, proj | ||
| 1424 | |||
| 1425 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | n = this%indices(1) |
| 1426 |
2/4✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
|
1 | k = this%indices(2) |
| 1427 |
6/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
|
1 | num_samples = size(upstream_grad, 2) |
| 1428 | |||
| 1429 | ! Recompute Q and R from B | ||
| 1430 |
27/54✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
|
1 | allocate(B(n, k), Q(n, k), R(k, k)) |
| 1431 |
11/20✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 2 times.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
|
3 | B = reshape(this%right_operand%val(:,1), [n, k]) |
| 1432 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | Q = B |
| 1433 | 1 | R = 0.0_real32 | |
| 1434 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1435 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = 1, j - 1 |
| 1436 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 1 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
|
5 | R(i,j) = dot_product(Q(:,i), Q(:,j)) |
| 1437 |
26/50✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✗ Branch 66 not taken.
✓ Branch 67 taken 1 times.
✗ Branch 69 not taken.
✓ Branch 70 taken 1 times.
✓ Branch 72 taken 4 times.
✓ Branch 73 taken 1 times.
|
7 | Q(:,j) = Q(:,j) - R(i,j) * Q(:,i) |
| 1438 | end do | ||
| 1439 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
✗ Branch 41 not taken.
✓ Branch 42 taken 2 times.
✗ Branch 44 not taken.
✓ Branch 45 taken 2 times.
✗ Branch 47 not taken.
✓ Branch 48 taken 2 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 2 times.
|
10 | R(j,j) = sqrt(dot_product(Q(:,j), Q(:,j))) |
| 1440 |
5/10✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✓ Branch 12 taken 2 times.
✗ Branch 13 not taken.
|
3 | if(R(j,j) .gt. 1.0e-12_real32)then |
| 1441 |
19/36✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 2 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 2 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 2 times.
✓ Branch 51 taken 8 times.
✓ Branch 52 taken 2 times.
|
10 | Q(:,j) = Q(:,j) / R(j,j) |
| 1442 | else | ||
| 1443 | ✗ | Q(:,j) = 0.0_real32 | |
| 1444 | end if | ||
| 1445 | end do | ||
| 1446 | |||
| 1447 |
27/54✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✓ Branch 4 taken 1 times.
✗ Branch 5 not taken.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✓ Branch 21 taken 1 times.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✓ Branch 25 taken 1 times.
✗ Branch 26 not taken.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✓ Branch 30 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✓ Branch 42 taken 1 times.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 1 times.
✓ Branch 46 taken 1 times.
✗ Branch 47 not taken.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 50 not taken.
✓ Branch 51 taken 1 times.
✗ Branch 52 not taken.
✓ Branch 53 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
|
1 | allocate(dQ(n, k), dQ_work(n, k), dB(n, k)) |
| 1448 |
14/28✓ Branch 0 taken 1 times.
✗ Branch 1 not taken.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 11 not taken.
✓ Branch 12 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
✓ Branch 17 taken 1 times.
✗ Branch 18 not taken.
✗ Branch 19 not taken.
✓ Branch 20 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 23 not taken.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 28 not taken.
✓ Branch 29 taken 1 times.
✗ Branch 31 not taken.
✓ Branch 32 taken 1 times.
|
1 | allocate(dv(n), v_recon(n)) |
| 1449 | |||
| 1450 |
10/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✓ Branch 18 taken 1 times.
✓ Branch 19 taken 1 times.
✓ Branch 20 taken 8 times.
✓ Branch 21 taken 1 times.
|
10 | output = 0.0_real32 |
| 1451 | |||
| 1452 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 1 times.
|
2 | do s = 1, num_samples |
| 1453 | ! dL/dQ for this sample: upstream(:,s) outer x(:,s) | ||
| 1454 | ! dQ[i_n, j_k] = upstream(i_n, s) * x(j_k, s) | ||
| 1455 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = 1, k |
| 1456 |
2/2✓ Branch 0 taken 8 times.
✓ Branch 1 taken 2 times.
|
11 | do i = 1, n |
| 1457 |
12/24✗ Branch 0 not taken.
✓ Branch 1 taken 8 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 8 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 8 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 8 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 8 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 8 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 8 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 8 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 8 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 8 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 8 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 8 times.
|
10 | dQ(i, j) = upstream_grad(i, s) * this%left_operand%val(j, s) |
| 1458 | end do | ||
| 1459 | end do | ||
| 1460 | |||
| 1461 | ! Gram-Schmidt backward: dQ -> dB | ||
| 1462 |
15/36✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✓ Branch 24 taken 1 times.
✗ Branch 25 not taken.
✓ Branch 26 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 29 not taken.
✗ Branch 30 not taken.
✗ Branch 31 not taken.
✗ Branch 32 not taken.
✗ Branch 33 not taken.
✗ Branch 34 not taken.
✗ Branch 35 not taken.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✗ Branch 39 not taken.
✓ Branch 40 taken 2 times.
✓ Branch 41 taken 1 times.
✓ Branch 42 taken 8 times.
✓ Branch 43 taken 2 times.
|
11 | dQ_work = dQ |
| 1463 | 1 | dB = 0.0_real32 | |
| 1464 | |||
| 1465 |
2/2✓ Branch 0 taken 2 times.
✓ Branch 1 taken 1 times.
|
3 | do j = k, 1, -1 |
| 1466 |
4/8✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
|
2 | norm_j = R(j, j) |
| 1467 |
1/2✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
|
2 | if(norm_j .le. 1.0e-12_real32)then |
| 1468 | ✗ | dB(:,j) = 0.0_real32 | |
| 1469 | ✗ | cycle | |
| 1470 | end if | ||
| 1471 | |||
| 1472 | ! Backward through normalization | ||
| 1473 |
15/28✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 8 times.
✓ Branch 40 taken 2 times.
|
10 | dprod = dot_product(dQ_work(:,j), Q(:,j)) |
| 1474 |
17/34✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 2 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 2 times.
✓ Branch 39 taken 2 times.
✗ Branch 40 not taken.
✓ Branch 41 taken 2 times.
✗ Branch 42 not taken.
✗ Branch 43 not taken.
✗ Branch 44 not taken.
✓ Branch 45 taken 8 times.
✓ Branch 46 taken 2 times.
|
10 | dv = (dQ_work(:,j) - dprod * Q(:,j)) / norm_j |
| 1475 | |||
| 1476 | ! Reconstruct v before normalization | ||
| 1477 |
10/20✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✓ Branch 18 taken 2 times.
✗ Branch 19 not taken.
✓ Branch 20 taken 2 times.
✗ Branch 21 not taken.
✗ Branch 22 not taken.
✗ Branch 23 not taken.
✓ Branch 24 taken 8 times.
✓ Branch 25 taken 2 times.
|
10 | v_recon = norm_j * Q(:,j) |
| 1478 | |||
| 1479 | ! Backward through projections (reverse order) | ||
| 1480 |
2/2✓ Branch 0 taken 1 times.
✓ Branch 1 taken 2 times.
|
3 | do i = j-1, 1, -1 |
| 1481 |
19/38✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✓ Branch 45 taken 1 times.
✗ Branch 46 not taken.
✓ Branch 47 taken 1 times.
✗ Branch 48 not taken.
✗ Branch 49 not taken.
✗ Branch 50 not taken.
✓ Branch 51 taken 4 times.
✓ Branch 52 taken 1 times.
|
5 | v_recon = v_recon + R(i,j) * Q(:,i) |
| 1482 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✓ Branch 33 taken 4 times.
✓ Branch 34 taken 1 times.
|
5 | dR_ij = -dot_product(dv, Q(:,i)) |
| 1483 |
24/46✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✗ Branch 54 not taken.
✓ Branch 55 taken 1 times.
✗ Branch 57 not taken.
✓ Branch 58 taken 1 times.
✗ Branch 60 not taken.
✓ Branch 61 taken 1 times.
✗ Branch 63 not taken.
✓ Branch 64 taken 1 times.
✓ Branch 66 taken 4 times.
✓ Branch 67 taken 1 times.
|
5 | dQ_work(:,i) = dQ_work(:,i) - R(i,j) * dv |
| 1484 |
20/38✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✗ Branch 33 not taken.
✓ Branch 34 taken 1 times.
✗ Branch 36 not taken.
✓ Branch 37 taken 1 times.
✗ Branch 39 not taken.
✓ Branch 40 taken 1 times.
✗ Branch 42 not taken.
✓ Branch 43 taken 1 times.
✗ Branch 45 not taken.
✓ Branch 46 taken 1 times.
✗ Branch 48 not taken.
✓ Branch 49 taken 1 times.
✗ Branch 51 not taken.
✓ Branch 52 taken 1 times.
✓ Branch 54 taken 4 times.
✓ Branch 55 taken 1 times.
|
5 | dQ_work(:,i) = dQ_work(:,i) + dR_ij * v_recon |
| 1485 |
15/30✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 1 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 1 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 1 times.
✓ Branch 33 taken 1 times.
✗ Branch 34 not taken.
✓ Branch 35 taken 1 times.
✗ Branch 36 not taken.
✗ Branch 37 not taken.
✗ Branch 38 not taken.
✓ Branch 39 taken 4 times.
✓ Branch 40 taken 1 times.
|
7 | dv = dv + dR_ij * Q(:,i) |
| 1486 | end do | ||
| 1487 | |||
| 1488 |
13/24✗ Branch 0 not taken.
✓ Branch 1 taken 2 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 2 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 2 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 2 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 2 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 2 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 2 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 2 times.
✗ Branch 24 not taken.
✓ Branch 25 taken 2 times.
✗ Branch 27 not taken.
✓ Branch 28 taken 2 times.
✗ Branch 30 not taken.
✓ Branch 31 taken 2 times.
✓ Branch 33 taken 8 times.
✓ Branch 34 taken 2 times.
|
11 | dB(:,j) = dv |
| 1489 | end do | ||
| 1490 | |||
| 1491 |
7/12✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✓ Branch 15 taken 1 times.
✓ Branch 16 taken 1 times.
|
3 | output(:, s) = reshape(dB, [n*k]) |
| 1492 | end do | ||
| 1493 | |||
| 1494 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 3 not taken.
✓ Branch 4 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 9 not taken.
✓ Branch 10 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 15 not taken.
✓ Branch 16 taken 1 times.
✗ Branch 18 not taken.
✓ Branch 19 taken 1 times.
✗ Branch 21 not taken.
✓ Branch 22 taken 1 times.
|
1 | deallocate(B, Q, R, dQ, dQ_work, dB, dv, v_recon) |
| 1495 | |||
| 1496 |
8/16✗ Branch 0 not taken.
✓ Branch 1 taken 1 times.
✗ Branch 2 not taken.
✓ Branch 3 taken 1 times.
✗ Branch 4 not taken.
✓ Branch 5 taken 1 times.
✗ Branch 6 not taken.
✓ Branch 7 taken 1 times.
✗ Branch 8 not taken.
✓ Branch 9 taken 1 times.
✗ Branch 10 not taken.
✓ Branch 11 taken 1 times.
✗ Branch 12 not taken.
✓ Branch 13 taken 1 times.
✗ Branch 14 not taken.
✓ Branch 15 taken 1 times.
|
1 | end subroutine get_partial_ono_decode_basis_val |
| 1497 | !############################################################################### | ||
| 1498 | |||
| 1499 | end submodule athena__diffstruc_extd_submodule_nop | ||
| 1500 |