| Line | Branch | Exec | Source |
|---|---|---|---|
submodule (athena__diffstruc_extd) athena__diffstruc_extd_loss_submodule
  !! Submodule containing implementations for extended diffstruc array operations
  !! (loss functions).
  use coreutils, only: stop_program
  use diffstruc, only: sign, merge, abs, operator(.le.)

contains

!###############################################################################
  module function huber_array(delta, gamma) result( output )
    !! Huber loss function, applied elementwise to delta.
    !!
    !! For each element d of delta:
    !!   0.5 * d**2                    if |d| <= gamma
    !!   gamma * (|d| - 0.5 * gamma)   otherwise
    !!
    !! gamma is stored in a 1x1 array attached as output%right_operand so the
    !! partial-derivative procedures below can read it back.
    implicit none
    class(array_type), intent(in), target :: delta
    !! Input values (residuals)
    real(real32), intent(in) :: gamma
    !! Threshold between the quadratic and linear regimes
    type(array_type), pointer :: output
    !! Elementwise Huber loss, same shape as delta

    type(array_type), pointer :: b_array

    output => delta%create_result()
    where (abs(delta%val) .le. gamma)
       ! Integer exponent on purpose: raising a negative real to a REAL power
       ! is prohibited by the Fortran standard, and delta may be negative.
       output%val = 0.5_real32 * delta%val**2
    elsewhere
       output%val = gamma * (abs(delta%val) - 0.5_real32 * gamma)
    end where

    output%get_partial_left => get_partial_huber
    output%get_partial_left_val => get_partial_huber_val
    if(delta%requires_grad)then
       output%requires_grad = .true.
       output%is_forward = delta%is_forward
       output%operation = 'huber'
       output%left_operand => delta
       output%owns_left_operand = delta%is_temporary
    end if

    ! Scalar (1x1) holder for gamma; it is owned by output and never needs a
    ! gradient of its own.
    allocate(b_array)
    b_array%is_sample_dependent = .false.
    b_array%is_scalar = .true.
    b_array%requires_grad = .false.
    call b_array%allocate(array_shape=[1, 1])
    b_array%val(1,1) = gamma
    output%right_operand => b_array
    output%owns_right_operand = .true.

  end function huber_array
!-------------------------------------------------------------------------------
  function get_partial_huber(this, upstream_grad) result(output)
    !! Gradient of the Huber loss with respect to its left operand (delta),
    !! chained with the upstream gradient.
    !!
    !! Local derivative, for each element d of delta:
    !!   d                     if |d| <= gamma
    !!   gamma * sign(1, d)    otherwise
    !! The returned value is upstream_grad multiplied elementwise by it.
    implicit none
    class(array_type), intent(inout) :: this
    !! Huber output node: left_operand is delta, right_operand holds gamma
    type(array_type), intent(in) :: upstream_grad
    !! Gradient flowing in from the consumer of this node
    type(array_type) :: output
    !! upstream_grad * d(huber)/d(delta)

    type(array_type), pointer :: ptr

    ! Chain rule: the upstream gradient must scale the local derivative;
    ! previously it was ignored.
    ptr => upstream_grad * merge( &
         this%left_operand, &
         this%right_operand%val(1,1) * sign(1._real32, this%left_operand), &
         abs(this%left_operand) .le. this%right_operand%val(1,1) &
    )

    call output%assign_and_deallocate_source(ptr)
  end function get_partial_huber
!-------------------------------------------------------------------------------
  pure subroutine get_partial_huber_val(this, upstream_grad, output)
    !! Gradient of the Huber loss with respect to delta, value-only version.
    !!
    !! Writes output = upstream_grad * d(huber)/d(delta), elementwise, where
    !! the local derivative is d for |d| <= gamma and gamma * sign(1, d) otherwise.
    implicit none
    class(array_type), intent(in) :: this
    !! Huber output node: left_operand is delta, right_operand holds gamma
    real(real32), dimension(:,:), intent(in) :: upstream_grad
    !! Gradient flowing in from the consumer of this node
    real(real32), dimension(:,:), intent(out) :: output
    !! Gradient with respect to delta

    associate( &
         d => this%left_operand%val, &
         threshold => this%right_operand%val(1,1) &
    )
      where (abs(d) .le. threshold)
         output = upstream_grad * d
      elsewhere
         output = upstream_grad * threshold * sign(1._real32, d)
      end where
    end associate

  end subroutine get_partial_huber_val
!###############################################################################

end submodule athena__diffstruc_extd_loss_submodule
| 80 |