module athena__activation_swish
  !! Module containing implementation of the swish activation function
  !!
  !! This module implements Swish (also called SiLU), a smooth, non-monotonic
  !! activation function discovered by Google researchers.
  !!
  !! Mathematical operation:
  !! \[ f(x) = x \cdot \sigma(\beta x) = \frac{x}{1 + e^{-\beta x}} \]
  !!
  !! where \(\beta\) is a parameter (typically \(\beta=1\), making it SiLU)
  !!
  !! Derivative:
  !! \[ f'(x) = \beta f(x) + \sigma(\beta x)(1 - \beta f(x)) \]
  !!
  !! Properties: Smooth, self-gated, unbounded above, bounded below
  !! (for \(\beta > 0\) the minimum is approximately \(-0.278/\beta\),
  !! i.e. slightly below zero, not zero itself)
  !! Often outperforms ReLU in deep networks
  !! Reference: Ramachandran et al. (2017), arXiv:1710.05941
  use coreutils, only: real32, print_warning
  use diffstruc, only: array_type, operator(*)
  use athena__misc_types, only: base_actv_type
  use athena__misc_types, only: onnx_attribute_type
  use athena__diffstruc_extd, only: swish
  implicit none

  private

  public :: swish_actv_type, create_from_onnx_swish_activation

  real(real32), parameter :: scale_tolerance = 1.e-6_real32
  !! Scale factors within this distance of 1 are treated as "no scaling"

  type, extends(base_actv_type) :: swish_actv_type
    !! Type for swish activation function with overloaded procedures
    real(real32) :: beta = 1._real32
    !! Beta parameter for swish function
  contains
    procedure, pass(this) :: apply => apply_swish
    procedure, pass(this) :: reset => reset_swish
    procedure, pass(this) :: apply_attributes => apply_attributes_swish
    procedure, pass(this) :: export_attributes => export_attributes_swish
  end type swish_actv_type

  interface swish_actv_type
    !! Interface for setting up swish activation function
    procedure initialise
  end interface swish_actv_type

contains

!###############################################################################
  function initialise(scale, beta, attributes) result(activation)
    !! Initialise a swish activation function
    !!
    !! The activation is first reset to its defaults, then the optional
    !! scale and beta arguments are applied, and finally any ONNX attributes.
    !! Attributes are applied last, so a "scale" or "beta" attribute
    !! overrides the corresponding argument.
    implicit none

    ! Arguments
    real(real32), intent(in), optional :: scale
    !! Optional scale factor for activation output
    real(real32), intent(in), optional :: beta
    !! Optional beta parameter for swish function
    type(onnx_attribute_type), dimension(:), intent(in), optional :: attributes
    !! Optional array of ONNX attributes
    type(swish_actv_type) :: activation
    !! Swish activation type


    call activation%reset()

    if(present(scale)) activation%scale = scale
    ! Only enable the scaling step when the factor differs from unity
    if(abs(activation%scale-1._real32) .gt. scale_tolerance)then
      activation%apply_scaling = .true.
    end if

    if(present(beta)) activation%beta = beta

    if(present(attributes)) then
      call activation%apply_attributes(attributes)
    end if

  end function initialise
!-------------------------------------------------------------------------------
  pure subroutine reset_swish(this)
    !! Reset swish activation function attributes and variables
    !!
    !! Restores name, scale, threshold, scaling flag and beta to defaults.
    implicit none

    ! Arguments
    class(swish_actv_type), intent(inout) :: this
    !! Swish activation type

    this%name = "swish"
    this%scale = 1._real32
    this%threshold = 0._real32
    this%apply_scaling = .false.
    this%beta = 1._real32

  end subroutine reset_swish
!-------------------------------------------------------------------------------
  function create_from_onnx_swish_activation(attributes) result(activation)
    !! Create swish activation function from ONNX attributes
    !!
    !! Returns the new activation as a polymorphic base activation object.
    implicit none

    ! Arguments
    type(onnx_attribute_type), dimension(:), intent(in) :: attributes
    !! Array of ONNX attributes

    class(base_actv_type), allocatable :: activation
    !! Instance of activation type

    allocate(activation, source = swish_actv_type(attributes = attributes))

  end function create_from_onnx_swish_activation
!###############################################################################


!###############################################################################
  subroutine apply_attributes_swish(this, attributes)
    !! Load ONNX attributes into swish activation function
    !!
    !! Recognised attribute names are "scale", "beta" and "name".
    !! A "name" that differs from the stored name, or any unrecognised
    !! attribute name, only triggers a warning.
    implicit none

    ! Arguments
    class(swish_actv_type), intent(inout) :: this
    !! Swish activation type
    type(onnx_attribute_type), dimension(:), intent(in) :: attributes
    !! Array of ONNX attributes

    ! Local variables
    integer :: i
    !! Loop variable

    ! Load provided attributes
    do i=1, size(attributes,dim=1)
      select case(trim(attributes(i)%name))
      case("scale")
        read(attributes(i)%val,*) this%scale
        ! Keep the scaling flag in sync with the newly read factor
        if(abs(this%scale-1._real32) .gt. scale_tolerance)then
          this%apply_scaling = .true.
        else
          this%apply_scaling = .false.
        end if
      case("beta")
        read(attributes(i)%val,*) this%beta
      case("name")
        if(trim(attributes(i)%val) .ne. trim(this%name)) then
          ! Inside a single-quoted literal each " is one literal character,
          ! so exactly one closing quote follows the reported value
          call print_warning( &
               'Swish activation: name attribute "' // &
               trim(attributes(i)%val) // &
               '" does not match expected "' // trim(this%name)//'"' &
          )

        end if
      case default
        call print_warning( &
             'Swish activation: unknown attribute '// &
             trim(attributes(i)%name) &
        )
      end select
    end do

  end subroutine apply_attributes_swish
!###############################################################################


!###############################################################################
  pure function export_attributes_swish(this) result(attributes)
    !! Export swish activation function attributes as ONNX attributes
    !!
    !! Produces three attributes, in order: "name", "scale" and "beta".
    !! Reals are written as F10.6; if a value does not fit in that field
    !! it is written as ES15.7 instead, so the exported text can always
    !! be read back with a list-directed read.
    implicit none

    ! Arguments
    class(swish_actv_type), intent(in) :: this
    !! Swish activation type
    type(onnx_attribute_type), allocatable, dimension(:) :: attributes
    !! Array of ONNX attributes

    ! Local variables
    character(50) :: buffer
    !! Temporary string buffer

    allocate(attributes(3))

    write(buffer, '(A)') this%name
    attributes(1) = onnx_attribute_type( &
         "name", "string", trim(adjustl(buffer)) )

    write(buffer, '(F10.6)') this%scale
    ! A field that is too narrow is filled with asterisks; fall back to
    ! exponent notation rather than exporting an unreadable value
    if(index(buffer, '*') .ne. 0) write(buffer, '(ES15.7)') this%scale
    attributes(2) = onnx_attribute_type( &
         "scale", "float", trim(adjustl(buffer)) )

    write(buffer, '(F10.6)') this%beta
    if(index(buffer, '*') .ne. 0) write(buffer, '(ES15.7)') this%beta
    attributes(3) = onnx_attribute_type( &
         "beta", "float", trim(adjustl(buffer)) )

  end function export_attributes_swish
!###############################################################################


!###############################################################################
  function apply_swish(this, val) result(output)
    !! Apply swish activation to the input array
    !!
    !! Computes: f(x) = x * sigmoid(β*x) = x / (1 + exp(-β*x))
    !! The result is multiplied by the scale factor when scaling is enabled.
    implicit none

    ! Arguments
    class(swish_actv_type), intent(in) :: this
    !! Swish activation type
    type(array_type), intent(in) :: val
    !! Input values
    type(array_type), pointer :: output
    !! Swish activation output

    ! The swish evaluation itself is delegated to the imported swish
    ! procedure; only the optional output scaling is handled here
    if(this%apply_scaling)then
      output => swish(val, this%beta) * this%scale
    else
      output => swish(val, this%beta)
    end if
  end function apply_swish
!###############################################################################

end module athena__activation_swish