| Line | Branch | Exec | Source |
|---|---|---|---|
| 1 | | | submodule (athena__diffstruc_extd) athena__diffstruc_extd_loss_submodule |
| 2 | | | !! Submodule containing implementations for extended diffstruc array operations |
| 3 | | | use coreutils, only: stop_program |
| 4 | | | use diffstruc, only: sign, merge, abs, operator(.le.) |
| 5 | | | |
| 6 | | | contains |
| 7 | | | |
| 8 | | | !############################################################################### |
| 9 | | − | module function huber_array(delta, gamma) result( output ) |
| 10 | | | !! Huber loss function |
| 11 | | | implicit none |
| 12 | | | class(array_type), intent(in), target :: delta |
| 13 | | | real(real32), intent(in) :: gamma |
| 14 | | | type(array_type), pointer :: output |
| 15 | | | |
| 16 | | | type(array_type), pointer :: b_array |
| 17 | | | |
| 18 | | − | output => delta%create_result() |
| 19 | | − | where (abs(delta%val) .le. gamma) |
| 20 | | − | output%val = 0.5_real32 * (delta%val)**2._real32 |
| 21 | | | elsewhere |
| 22 | | − | output%val = gamma * (abs(delta%val) - 0.5_real32 * gamma) |
| 23 | | | end where |
| 24 | | | |
| 25 | | − | output%get_partial_left => get_partial_huber |
| 26 | | − | output%get_partial_left_val => get_partial_huber_val |
| 27 | | − | if(delta%requires_grad)then |
| 28 | | − | output%requires_grad = .true. |
| 29 | | − | output%is_forward = delta%is_forward |
| 30 | | − | output%operation = 'huber' |
| 31 | | − | output%left_operand => delta |
| 32 | | − | output%owns_left_operand = delta%is_temporary |
| 33 | | | end if |
| 34 | | − | allocate(b_array) |
| 35 | | − | b_array%is_sample_dependent = .false. |
| 36 | | − | b_array%is_scalar = .true. |
| 37 | | − | b_array%requires_grad = .false. |
| 38 | | − | call b_array%allocate(array_shape=[1, 1]) |
| 39 | | − | b_array%val(1,1) = gamma |
| 40 | | − | output%right_operand => b_array |
| 41 | | − | output%owns_right_operand = .true. |
| 42 | | | |
| 43 | | − | end function huber_array |
| 44 | | | !------------------------------------------------------------------------------- |
| 45 | | − | function get_partial_huber(this, upstream_grad) result(output) |
| 46 | | | !! Get partial derivative of huber loss |
| 47 | | | implicit none |
| 48 | | | class(array_type), intent(inout) :: this |
| 49 | | | type(array_type), intent(in) :: upstream_grad |
| 50 | | | type(array_type) :: output |
| 51 | | | |
| 52 | | | type(array_type), pointer :: ptr |
| 53 | | | |
| 54 | | | ptr => merge( & |
| 55 | | | this%left_operand, & |
| 56 | | − | this%right_operand%val(1,1) * sign(1._real32, this%left_operand), & |
| 57 | | − | abs(this%left_operand) .le. this%right_operand%val(1,1) & |
| 58 | | − | ) |
| 59 | | | |
| 60 | | − | call output%assign_and_deallocate_source(ptr) |
| 61 | | − | end function get_partial_huber |
| 62 | | | !------------------------------------------------------------------------------- |
| 63 | | − | pure subroutine get_partial_huber_val(this, upstream_grad, output) |
| 64 | | | !! Get partial derivative of huber loss (in-place version) |
| 65 | | | implicit none |
| 66 | | | class(array_type), intent(in) :: this |
| 67 | | | real(real32), dimension(:,:), intent(in) :: upstream_grad |
| 68 | | | real(real32), dimension(:,:), intent(out) :: output |
| 69 | | | |
| 70 | | − | where (abs(this%left_operand%val) .le. this%right_operand%val(1,1)) |
| 71 | | − | output = this%left_operand%val |
| 72 | | | elsewhere |
| 73 | | − | output = this%right_operand%val(1,1) * sign(1._real32, this%left_operand%val) |
| 74 | | | end where |
| 75 | | | |
| 76 | | − | end subroutine get_partial_huber_val |
| 77 | | | !############################################################################### |
| 78 | | | |
| 79 | | | end submodule athena__diffstruc_extd_loss_submodule |
| 80 | | | |