GCC Code Coverage Report


Directory: src/athena/
File: athena_diffstruc_extd_loss.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 submodule (athena__diffstruc_extd) athena__diffstruc_extd_loss_submodule
2 !! Submodule containing implementations for extended diffstruc array operations
3 use coreutils, only: stop_program
4 use diffstruc, only: sign, merge, abs, operator(.le.)
5
6 contains
7
8 !###############################################################################
module function huber_array(delta, gamma) result( output )
  !! Element-wise Huber loss.
  !!
  !! For every element d of `delta%val` the result holds
  !!   0.5 * d**2                     where |d| <= gamma
  !!   gamma * (|d| - 0.5 * gamma)    elsewhere
  !!
  !! Arguments:
  !!   delta  - array the loss is evaluated on (its `val` component is read)
  !!   gamma  - threshold separating the quadratic and linear branches
  !! Result:
  !!   output - pointer to the array obtained from `delta%create_result()`,
  !!            filled with the loss values and wired with the Huber
  !!            partial-derivative procedures
  implicit none
  class(array_type), intent(in), target :: delta
  real(real32), intent(in) :: gamma
  type(array_type), pointer :: output

  type(array_type), pointer :: b_array

  output => delta%create_result()

  ! Use an integer exponent for the square: elements of delta%val may be
  ! negative, and raising a negative real to a *real* power is not permitted
  ! by the Fortran standard.
  where (abs(delta%val) .le. gamma)
    output%val = 0.5_real32 * delta%val**2
  elsewhere
    output%val = gamma * (abs(delta%val) - 0.5_real32 * gamma)
  end where

  ! The derivative procedures are attached unconditionally; the remaining
  ! graph bookkeeping is recorded only when delta requires a gradient.
  output%get_partial_left => get_partial_huber
  output%get_partial_left_val => get_partial_huber_val
  if (delta%requires_grad) then
    output%requires_grad = .true.
    output%is_forward = delta%is_forward
    output%operation = 'huber'
    output%left_operand => delta
    output%owns_left_operand = delta%is_temporary
  end if

  ! Store gamma as a 1x1 scalar right operand so that get_partial_huber and
  ! get_partial_huber_val can read it back through right_operand%val(1,1).
  allocate(b_array)
  b_array%is_sample_dependent = .false.
  b_array%is_scalar = .true.
  b_array%requires_grad = .false.
  call b_array%allocate(array_shape=[1, 1])
  b_array%val(1,1) = gamma
  output%right_operand => b_array
  output%owns_right_operand = .true.

end function huber_array
44 !-------------------------------------------------------------------------------
function get_partial_huber(this, upstream_grad) result(output)
  !! Partial derivative of the Huber loss with respect to its left operand,
  !! chained with the upstream gradient.
  !!
  !! The local derivative is built element-wise with `merge`:
  !!   left_operand                        where |left_operand| <= gamma
  !!   gamma * sign(1, left_operand)       elsewhere
  !! with gamma read from `this%right_operand%val(1,1)`.
  !!
  !! Arguments:
  !!   this          - result node of the Huber operation
  !!   upstream_grad - gradient flowing back into `this`
  !! Result:
  !!   output        - upstream_grad multiplied by the local derivative
  implicit none
  class(array_type), intent(inout) :: this
  type(array_type), intent(in) :: upstream_grad
  type(array_type) :: output

  type(array_type), pointer :: ptr

  ! Apply the chain rule: the upstream gradient was previously ignored, so
  ! the result was only correct when the upstream gradient was all ones.
  ! NOTE(review): relies on operator(*) for two array_type operands being
  ! accessible here from the parent module - confirm.
  ptr => upstream_grad * merge( &
       this%left_operand, &
       this%right_operand%val(1,1) * sign(1._real32, this%left_operand), &
       abs(this%left_operand) .le. this%right_operand%val(1,1) &
  )

  ! Copy the expression result into output and release the source pointer.
  call output%assign_and_deallocate_source(ptr)
end function get_partial_huber
62 !-------------------------------------------------------------------------------
pure subroutine get_partial_huber_val(this, upstream_grad, output)
  !! Partial derivative of the Huber loss with respect to its left operand
  !! (in-place, raw-value version), chained with the upstream gradient.
  !!
  !! Element-wise, with d = this%left_operand%val and
  !! gamma = this%right_operand%val(1,1):
  !!   output = upstream_grad * d                     where |d| <= gamma
  !!   output = upstream_grad * gamma * sign(1, d)    elsewhere
  !!
  !! Arguments:
  !!   this          - result node of the Huber operation
  !!   upstream_grad - gradient values flowing back into `this`
  !!   output        - receives the gradient with respect to the left operand
  implicit none
  class(array_type), intent(in) :: this
  real(real32), dimension(:,:), intent(in) :: upstream_grad
  real(real32), dimension(:,:), intent(out) :: output

  ! Apply the chain rule: the upstream gradient was previously ignored.
  ! NOTE(review): assumes upstream_grad conforms with this%left_operand%val
  ! (as output already must inside the where construct) - confirm.
  associate( &
       residual => this%left_operand%val, &
       gamma => this%right_operand%val(1,1) &
  )
    where (abs(residual) .le. gamma)
      output = upstream_grad * residual
    elsewhere
      output = upstream_grad * gamma * sign(1._real32, residual)
    end where
  end associate

end subroutine get_partial_huber_val
77 !###############################################################################
78
79 end submodule athena__diffstruc_extd_loss_submodule
80