GCC Code Coverage Report


Directory: src/athena/
File: athena_loss.f90
Date: 2025-12-10 07:37:07
Exec Total Coverage
Lines: 0 0 100.0%
Functions: 0 0 -%
Branches: 0 0 -%

Line Branch Exec Source
1 module athena__loss
2 !! Module containing loss function implementations
3 !!
4 !! This module implements loss functions that quantify the difference between
5 !! model predictions and target values, guiding the optimisation process.
6 !!
7 !! Implemented loss functions:
8 !!
9 !! Mean Squared Error (MSE):
10 !! L = (1/N) Σ (y_pred - y_true)²
11 !! For regression, sensitive to outliers
12 !!
13 !! Mean Absolute Error (MAE):
14 !! L = (1/N) Σ |y_pred - y_true|
15 !! For regression, robust to outliers
16 !!
17 !! Binary Cross-Entropy:
18 !! L = -(1/N) Σ [y*log(ŷ) + (1-y)*log(1-ŷ)]
19 !! For binary classification (outputs in [0,1])
20 !!
21 !! Categorical Cross-Entropy:
22 !! L = -(1/N) Σ_i Σ_c y_{i,c} * log(ŷ_{i,c})
23 !! For multi-class classification with one-hot encoded targets
24 !!
25 !! Sparse Categorical Cross-Entropy:
26 !! L = -(1/N) Σ log(ŷ_{i,c_i})
27 !! For multi-class with integer class labels
28 !!
29 !! Huber Loss:
30 !! L = (1/N) Σ { 0.5*(y-ŷ)² if |y-ŷ| ≤ δ
31 !! { δ*(|y-ŷ| - 0.5*δ) otherwise
32 !! Combines MSE and MAE, robust to outliers while smooth near zero
33 !!
34 !! where N is number of samples, y is true value, ŷ is prediction
35 use coreutils, only: real32
36 use diffstruc, only: array_type, operator(+), operator(-), &
37 operator(*), operator(/), operator(**), mean, sum, log, abs, merge
38 use athena__diffstruc_extd, only: huber
39 implicit none
40
41
42 private
43
44 public :: base_loss_type
45 public :: bce_loss_type
46 public :: cce_loss_type
47 public :: mae_loss_type
48 public :: mse_loss_type
49 public :: nll_loss_type
50 public :: huber_loss_type
51
52
type, abstract :: base_loss_type
  !! Abstract parent of every loss function type in this module.
  !! Concrete types extend it and supply their own `compute` binding.

  character(len=:), allocatable :: name
  !! Short identifier of the loss function (set by the constructors)

  real(real32) :: epsilon = 1.0e-10_real32
  !! Small offset added inside log() calls to avoid log(0)

  integer :: batch_index = 1
  !! Index of the batch to compute the loss for

  integer :: sample_index = 1
  !! Index of the sample to compute the loss for

contains

  procedure(compute_base), pass(this), deferred :: compute
  !! Compute the loss of a model; must be provided by each extension

end type base_loss_type
67
interface
  module function compute_base(this, predicted, expected) result(output)
    !! Interface shared by every `compute` binding of a loss type
    class(base_loss_type), intent(in), target :: this
    !! Instance of the loss function
    type(array_type), dimension(:,:), intent(inout), target :: predicted
    !! Predicted values
    type(array_type), &
         dimension(size(predicted,1),size(predicted,2)), &
         intent(in) :: expected
    !! Expected values, same shape as `predicted`
    type(array_type), pointer :: output
    !! Resulting loss
  end function compute_base
end interface
82
83 !-------------------------------------------------------------------------------
84
type, extends(base_loss_type) :: bce_loss_type
  !! Binary cross entropy loss function
contains
  procedure, pass(this) :: compute => compute_bce
  !! Compute the binary cross entropy loss of a model
end type bce_loss_type

interface bce_loss_type
  !! Generic constructor for the binary cross entropy loss type
  module function setup_loss_bce() result(loss)
    !! Build a binary cross entropy loss object
    type(bce_loss_type) :: loss
    !! Newly constructed binary cross entropy loss object
  end function setup_loss_bce
end interface bce_loss_type
100
101 !-------------------------------------------------------------------------------
102
type, extends(base_loss_type) :: cce_loss_type
  !! Categorical cross entropy loss function
contains
  procedure, pass(this) :: compute => compute_cce
  !! Compute the categorical cross entropy loss of a model
end type cce_loss_type

interface cce_loss_type
  !! Generic constructor for the categorical cross entropy loss type
  module function setup_loss_cce() result(loss)
    !! Build a categorical cross entropy loss object
    type(cce_loss_type) :: loss
    !! Newly constructed categorical cross entropy loss object
  end function setup_loss_cce
end interface cce_loss_type
118
119 !-------------------------------------------------------------------------------
120
type, extends(base_loss_type) :: mae_loss_type
  !! Mean absolute error loss function
contains
  procedure, pass(this) :: compute => compute_mae
  !! Compute the mean absolute error loss of a model
end type mae_loss_type

interface mae_loss_type
  !! Generic constructor for the mean absolute error loss type
  module function setup_loss_mae() result(loss)
    !! Build a mean absolute error loss object
    type(mae_loss_type) :: loss
    !! Newly constructed mean absolute error loss object
  end function setup_loss_mae
end interface mae_loss_type
136
137 !-------------------------------------------------------------------------------
138
type, extends(base_loss_type) :: mse_loss_type
  !! Mean squared error loss function
contains
  procedure, pass(this) :: compute => compute_mse
  !! Compute the mean squared error loss of a model
end type mse_loss_type

interface mse_loss_type
  !! Generic constructor for the mean squared error loss type
  module function setup_loss_mse() result(loss)
    !! Build a mean squared error loss object
    type(mse_loss_type) :: loss
    !! Newly constructed mean squared error loss object
  end function setup_loss_mse
end interface mse_loss_type
154
155 !-------------------------------------------------------------------------------
156
type, extends(base_loss_type) :: nll_loss_type
  !! Negative log likelihood loss function
contains
  procedure, pass(this) :: compute => compute_nll
  !! Compute the negative log likelihood loss of a model
end type nll_loss_type

interface nll_loss_type
  !! Generic constructor for the negative log likelihood loss type
  module function setup_loss_nll() result(loss)
    !! Build a negative log likelihood loss object
    type(nll_loss_type) :: loss
    !! Newly constructed negative log likelihood loss object
  end function setup_loss_nll
end interface nll_loss_type
172
173 !-------------------------------------------------------------------------------
174
type, extends(base_loss_type) :: huber_loss_type
  !! Huber loss function
  real(real32) :: gamma = 1.0_real32
  !! Gamma value handed to huber() when the loss is computed
contains
  procedure, pass(this) :: compute => compute_huber
  !! Compute the huber loss of a model
end type huber_loss_type

interface huber_loss_type
  !! Generic constructor for the huber loss type
  module function setup_loss_huber() result(loss)
    !! Build a huber loss object
    type(huber_loss_type) :: loss
    !! Newly constructed huber loss object
  end function setup_loss_huber
end interface huber_loss_type
192
193 !-------------------------------------------------------------------------------
194
195
196
197 contains
198 !###############################################################################
module function setup_loss_bce() result(loss)
  !! Constructor for the binary cross entropy loss function.
  !! Only the name is set here; all other components keep their defaults.
  implicit none

  ! Result
  type(bce_loss_type) :: loss
  !! Binary cross entropy loss object being built

  loss%name = 'bce'

end function setup_loss_bce
209 !-------------------------------------------------------------------------------
module function setup_loss_cce() result(loss)
  !! Constructor for the categorical cross entropy loss function.
  !! Only the name is set here; all other components keep their defaults.
  implicit none

  ! Result
  type(cce_loss_type) :: loss
  !! Categorical cross entropy loss object being built

  loss%name = 'cce'

end function setup_loss_cce
220 !-------------------------------------------------------------------------------
module function setup_loss_mae() result(loss)
  !! Constructor for the mean absolute error loss function.
  !! Only the name is set here; all other components keep their defaults.
  implicit none

  ! Result
  type(mae_loss_type) :: loss
  !! Mean absolute error loss object being built

  loss%name = 'mae'

end function setup_loss_mae
231 !-------------------------------------------------------------------------------
module function setup_loss_mse() result(loss)
  !! Constructor for the mean squared error loss function.
  !! Only the name is set here; all other components keep their defaults.
  implicit none

  ! Result
  type(mse_loss_type) :: loss
  !! Mean squared error loss object being built

  loss%name = 'mse'

end function setup_loss_mse
242 !-------------------------------------------------------------------------------
module function setup_loss_nll() result(loss)
  !! Constructor for the negative log likelihood loss function.
  !! Only the name is set here; all other components keep their defaults.
  implicit none

  ! Result
  type(nll_loss_type) :: loss
  !! Negative log likelihood loss object being built

  loss%name = 'nll'

end function setup_loss_nll
253 !-------------------------------------------------------------------------------
module function setup_loss_huber() result(loss)
  !! Constructor for the huber loss function.
  !! Only the name is set here; gamma and the other components keep
  !! their default values.
  implicit none

  ! Result
  type(huber_loss_type) :: loss
  !! Huber loss object being built

  loss%name = 'hub'

end function setup_loss_huber
264 !###############################################################################
265
266
267 !###############################################################################
function compute_bce(this, predicted, expected) result(output)
  !! Compute the binary cross entropy loss of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   mean( -expected * log(predicted + epsilon), dim=2 )
  !! and the contributions of all pairs are summed.
  !!
  !! NOTE(review): only the y*log(ŷ) term is evaluated here; the
  !! (1-y)*log(1-ŷ) term of the formula in the module header is not
  !! included. Confirm whether that is intentional.
  implicit none

  ! Arguments
  class(bce_loss_type), intent(in), target :: this
  !! Instance of the binary cross entropy loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Binary cross entropy loss

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Contribution of a single (i,s) pair

  ! Element (1,1) seeds the accumulator
  output => mean(-expected(1,1) * log(predicted(1,1) + this%epsilon), dim=2)

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => mean(-expected(i,s) * log(predicted(i,s) + this%epsilon), dim=2)

      output => output + ptr
    end do
  end do

end function compute_bce
303 !###############################################################################
304
305
306 !###############################################################################
function compute_cce(this, predicted, expected) result(output)
  !! Compute the categorical cross entropy loss of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   -mean( sum( expected * log(predicted + epsilon), dim=1 ), dim=2 )
  !! and the contributions of all pairs are summed.
  implicit none

  ! Arguments
  class(cce_loss_type), intent(in), target :: this
  !! Instance of the categorical cross entropy loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Categorical cross entropy loss

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Un-negated contribution of a single (i,s) pair

  ! Element (1,1) seeds the accumulator
  output => -mean( sum( &
       expected(1,1) * log(predicted(1,1) + this%epsilon), &
       dim=1 ), dim=2)

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => mean( sum( &
           expected(i,s) * log(predicted(i,s) + this%epsilon), &
           dim=1 ), dim=2)

      ! ptr holds the positive sum, so it is subtracted
      output => output - ptr
    end do
  end do

end function compute_cce
346 !###############################################################################
347
348
349 !###############################################################################
function compute_mae(this, predicted, expected) result(output)
  !! Compute the mean absolute error of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   mean( abs(predicted - expected) ) / 2
  !! and the contributions of all pairs are summed. Note that each
  !! contribution is halved relative to the formula in the module header.
  implicit none

  ! Arguments
  class(mae_loss_type), intent(in), target :: this
  !! Instance of the mean absolute error loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Mean absolute error

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Contribution of a single (i,s) pair

  ! Element (1,1) seeds the accumulator
  output => mean( abs( predicted(1,1) - expected(1,1) ) ) / &
       2._real32

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => mean( abs( predicted(i,s) - expected(i,s) ) ) / &
           2._real32

      output => output + ptr
    end do
  end do

end function compute_mae
386 !###############################################################################
387
388
389 !###############################################################################
function compute_mse(this, predicted, expected) result(output)
  !! Compute the mean squared error of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   mean( (predicted - expected)**2 ) / 2
  !! and the contributions of all pairs are summed. Note that each
  !! contribution is halved relative to the formula in the module header.
  implicit none

  ! Arguments
  class(mse_loss_type), intent(in), target :: this
  !! Instance of the mean squared error loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Mean squared error loss

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Contribution of a single (i,s) pair

  ! Element (1,1) seeds the accumulator
  output => mean( ( predicted(1,1) - expected(1,1) ) ** 2._real32 ) / &
       2._real32

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => mean( ( predicted(i,s) - expected(i,s) ) ** 2._real32 ) / &
           2._real32

      output => output + ptr
    end do
  end do

end function compute_mse
427 !###############################################################################
428
429
430 !###############################################################################
function compute_nll(this, predicted, expected) result(output)
  !! Compute the negative log likelihood of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   mean( -log(expected - predicted + epsilon) )
  !! and the contributions of all pairs are summed.
  implicit none

  ! Arguments
  class(nll_loss_type), intent(in), target :: this
  !! Instance of the negative log likelihood loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Negative log likelihood loss

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Contribution of a single (i,s) pair

  ! Element (1,1) seeds the accumulator
  output => mean(-log(expected(1,1) - predicted(1,1) + this%epsilon) )

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => mean(-log(expected(i,s) - predicted(i,s) + this%epsilon) )

      output => output + ptr
    end do
  end do

end function compute_nll
466 !###############################################################################
467
468
469 !###############################################################################
function compute_huber(this, predicted, expected) result(output)
  !! Compute the huber loss of a model
  !!
  !! Every allocated (predicted, expected) pair contributes
  !!   mean( huber(predicted - expected, gamma) )
  !! and the contributions of all pairs are summed.
  implicit none

  ! Arguments
  class(huber_loss_type), intent(in), target :: this
  !! Instance of the huber loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Huber loss

  ! Local variables
  integer :: s, i
  !! Loop indices
  type(array_type), pointer :: ptr
  !! Difference predicted - expected of a single (i,s) pair

  ! Element (1,1) seeds the accumulator; the difference is built once and
  ! reused instead of being constructed a second time
  ptr => predicted(1,1) - expected(1,1)
  output => mean( huber(ptr, this%gamma) )

  do s = 1, size(predicted,2)
    do i = 1, size(predicted,1)
      ! (1,1) is already contained in output; adding it again would
      ! count it twice
      if(i.eq.1 .and. s.eq.1) cycle
      if(.not.predicted(i,s)%allocated .or. &
           .not.expected(i,s)%allocated) cycle
      ptr => predicted(i,s) - expected(i,s)

      output => output + mean( huber(ptr, this%gamma) )
    end do
  end do

  ! Earlier element-wise formulation, kept for reference:
  ! output => merge( &
  !      0.5_real32 * (ptr)**2._real32, &
  !      this%gamma * (abs(ptr) - 0.5_real32 * this%gamma), &
  !      abs(ptr) .le. this%gamma &
  ! )

end function compute_huber
512 !###############################################################################
513
514
515 !###############################################################################
module function compute_base(this, predicted, expected) result(output)
  !! Placeholder for the compute function of base_loss_type
  !!
  !! Performs no computation. It exists so that the interface used by the
  !! deferred `compute` binding has a definition; concrete loss types
  !! supply their own `compute` procedures.
  implicit none

  ! Arguments
  class(base_loss_type), intent(in), target :: this
  !! Instance of the base loss function
  type(array_type), dimension(:,:), intent(inout), target :: predicted
  !! Predicted values
  type(array_type), dimension(size(predicted,1),size(predicted,2)), &
       intent(in) :: expected
  !! Expected values
  type(array_type), pointer :: output
  !! Loss value; always returned disassociated by this placeholder

  ! Give the pointer result a defined association status so that callers
  ! can safely test it with associated()
  output => null()

end function compute_base
532 !###############################################################################
533
534 end module athena__loss
535