Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/loss.py: 100%

7 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Quentin Gabot 

4 

5# The following is an adaptation of the initialization code from 

6# PyTorch library for the complex-valued neural networks 

7 

8# Permission is hereby granted, free of charge, to any person obtaining a copy 

9# of this software and associated documentation files (the "Software"), to deal 

10# in the Software without restriction, including without limitation the rights 

11# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

12# copies of the Software, and to permit persons to whom the Software is 

13# furnished to do so, subject to the following conditions: 

14 

15# The above copyright notice and this permission notice shall be included in 

16# all copies or substantial portions of the Software. 

17 

18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

24# SOFTWARE. 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29 

30 

class ComplexMSELoss(nn.modules.loss._Loss):
    r"""Mean Squared Error (MSE) loss for complex-valued (or real-valued) inputs.

    The element-wise loss is the squared magnitude of the difference between
    the ground truth and the prediction:

        l_i = |y_true_i - y_pred_i|^2

    For complex inputs this is ``Re(d)^2 + Im(d)^2``, so both the real and the
    imaginary parts of the error contribute. For real inputs it reduces to the
    usual ``(y_true_i - y_pred_i)^2``.

    The element-wise losses are then reduced according to ``reduction``:

        'none': no reduction, the tensor of element-wise losses is returned,
        'mean': MSE = mean(|y_true - y_pred|^2),
        'sum' : sum(|y_true - y_pred|^2).

    Parameters:
        size_average (bool, optional): Deprecated (use ``reduction``). Forwarded
            to ``torch.nn.modules.loss._Loss``, which combines it with
            ``reduce`` into the effective ``reduction``. Default is None.
        reduce (bool, optional): Deprecated (use ``reduction``). Forwarded to
            ``torch.nn.modules.loss._Loss`` as above; ``reduce=False`` yields
            ``reduction='none'``. Default is None.
        reduction (str, optional): Reduction applied to the element-wise
            losses: ``'none'`` | ``'mean'`` | ``'sum'``. Ignored when
            ``size_average`` or ``reduce`` is given. Default is ``'mean'``.

    Inputs:
        y_pred (torch.Tensor): The predicted values, real or complex.
        y_true (torch.Tensor): The ground truth values, real or complex. They
            should have the same shape as ``y_pred``. No shape check is
            performed, so broadcastable shapes are silently broadcast by the
            subtraction.

    Returns:
        torch.Tensor: A real-valued scalar for ``'mean'`` and ``'sum'``, or the
        element-wise squared errors for ``'none'``.

    Raises:
        ValueError: If the effective ``reduction`` is not one of
            ``'none'``, ``'mean'`` or ``'sum'``.

    Examples:
        >>> loss = ComplexMSELoss()
        >>> y_pred = torch.tensor([1+1j, 2+2j], dtype=torch.complex64)
        >>> y_true = torch.tensor([1+2j, 2+3j], dtype=torch.complex64)
        >>> output = loss(y_pred, y_true)
        >>> print(output)
        tensor(1.)
    """

    def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
        # _Loss resolves the deprecated size_average/reduce pair into
        # self.reduction, so forward only needs to consult self.reduction.
        super().__init__(size_average, reduce, reduction)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the (reduced) squared-magnitude error between y_pred and y_true."""
        # |d|^2 is real-valued for both real and complex d.
        squared_error = torch.square(torch.abs(y_true - y_pred))

        # Honour the configured reduction. Previously the mean was always
        # returned, which silently ignored 'sum' and 'none'.
        if self.reduction == "mean":
            return torch.mean(squared_error)
        if self.reduction == "sum":
            return torch.sum(squared_error)
        if self.reduction == "none":
            return squared_error
        raise ValueError(f"{self.reduction} is not a valid value for reduction")