# MIT License
# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
from abc import ABC, abstractmethod
from typing import Tuple, Union, Optional, Dict
from types import NoneType, ModuleType
# External imports
import torch
import numpy as np
from PIL import Image
# Internal imports
import torchcvnn.transforms.functional as F
class BaseTransform(ABC):
"""Abstract base class for transforms that can handle both numpy arrays and PyTorch tensors.
This class serves as a template for implementing transforms that can be applied to both numpy arrays
and PyTorch tensors while maintaining consistent behavior.
Inputs must be in CHW (Channel, Height, Width) format. If inputs have only 2 dimensions (Height, Width),
they will be converted to (1, Height, Width).
Args:
dtype (str, optional): Data type to convert inputs to. Must be one of:
'float32', 'float64', 'complex64', 'complex128'. If None, no type conversion is performed.
Default: None.
Raises:
AssertionError: If dtype is not a string or not one of the allowed types.
ValueError: If input is neither a numpy array nor a PyTorch tensor.
Methods:
__call__(x): Apply the transform to the input array/tensor.
__call_numpy__(x): Abstract method to implement numpy array transform.
__call_torch__(x): Abstract method to implement PyTorch tensor transform.
Example:
>>> class MyTransform(BaseTransform):
>>> def __call_numpy__(self, x):
>>> # Implement numpy transform
>>> pass
>>> def __call_torch__(self, x):
>>> # Implement torch transform
>>> pass
>>> transform = MyTransform(dtype='float32')
>>> output = transform(input_data) # Works with both numpy arrays and torch tensors
"""
def __init__(self, dtype: str | NoneType = None) -> None:
if dtype is not None:
assert isinstance(dtype, str), "dtype should be a string"
assert dtype in ["float32", "float64", "complex64", "complex128"], "dtype should be one of float32, float64, complex64, complex128"
self.np_dtype = getattr(np, dtype)
self.torch_dtype = getattr(torch, dtype)
def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
"""Apply transform to input."""
x = F.check_input(x)
if isinstance(x, np.ndarray):
return self.__call_numpy__(x)
elif isinstance(x, torch.Tensor):
return self.__call_torch__(x)
@abstractmethod
def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
"""Apply transform to numpy array."""
raise NotImplementedError
@abstractmethod
def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply transform to torch tensor."""
raise NotImplementedError
class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex data, optionally keeping the phase.

    The amplitude is clipped to ``[min_value, max_value]``, log10-transformed and
    normalized to the ``[0, 1]`` range. The work is delegated to
    ``F.log_normalize_amplitude`` for both backends. Steps:

        1. Extract amplitude and phase from the complex input.
        2. Clip the amplitude between ``min_value`` and ``max_value``.
        3. Apply log10 and normalize to ``[0, 1]``.
        4. Optionally recombine with the original phase.

    Args:
        min_value (int | float, optional): Lower clipping bound for the amplitude.
            Smaller values are clipped up. Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound for the amplitude.
            Larger values are clipped down. Defaults to 40.
        keep_phase (bool, optional): If True, return complex values made of the
            transformed amplitude and the original phase. If False, return only the
            transformed (real) amplitude. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Transformed data with the same shape as the input.
            Complex when ``keep_phase=True``, real when ``keep_phase=False``.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # Amplitudes mapped to log scale in [0, 1]
    """

    def __init__(
        self,
        min_value: float = 0.02,
        max_value: float = 40,
        keep_phase: bool = True,
    ) -> None:
        self.min_value = min_value
        self.max_value = max_value
        self.keep_phase = keep_phase

    def _log_normalize(self, x, backend):
        # Same functional helper for both backends; only the array module differs.
        return F.log_normalize_amplitude(
            x, backend, self.keep_phase, self.min_value, self.max_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._log_normalize(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._log_normalize(x, torch)
class Amplitude(BaseTransform):
    """Replace complex values by their magnitude (absolute value).

    Args:
        dtype (str): Output dtype name, validated by ``BaseTransform``
            ('float32', 'float64', 'complex64' or 'complex128').

    Returns:
        np.ndarray | torch.Tensor: ``|x|`` with the same shape as the input, cast to ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        magnitude = np.absolute(x)
        return magnitude.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        magnitude = x.abs()
        return magnitude.to(self.torch_dtype)
class RealImaginary(BaseTransform):
    """Split complex data into real and imaginary channels.

    For an input of shape (C, H, W), the real and imaginary parts are stacked into a
    (2, C, H, W) block, which is then merged into (2C, H, W). The first C output channels
    therefore hold the real parts and the last C channels the imaginary parts.

    Returns:
        np.ndarray | torch.Tensor: Real-valued data of shape (2*C, H, W), where C is the
            number of input channels.
    """

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        stacked = np.stack((x.real, x.imag), axis=0)  # (C, H, W) -> (2, C, H, W)
        merged_shape = (-1,) + stacked.shape[2:]
        return stacked.reshape(merged_shape)  # (2, C, H, W) -> (2C, H, W)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        stacked = torch.stack((x.real, x.imag), dim=0)  # (C, H, W) -> (2, C, H, W)
        return stacked.flatten(start_dim=0, end_dim=1)  # (2, C, H, W) -> (2C, H, W)
class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Every element is multiplied by ``exp(j * phi)``. Each ``phi`` is drawn independently,
    uniformly in ``[0, 2π)``, or in ``[-π, π)`` when ``centering`` is enabled. Because
    ``phi`` is real, ``|exp(j * phi)| == 1`` and the amplitude of each element is preserved.

    Args:
        dtype (str): Output data type, validated by ``BaseTransform``. Use a complex dtype
            ('complex64' or 'complex128') so the phase-shifted values are kept.
        centering (bool, optional): If True, subtract π from the sampled phases so that
            they are centered around 0. Default is False.

    Returns:
        torch.Tensor | numpy.ndarray: Phase-shifted data with the same shape as the input,
            cast to ``dtype``.

    Example:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # Applies random phase shifts

    Notes:
        - Intended for complex-valued input.
        - Phase shifts are uniformly distributed in:
            - [0, 2π) when centering=False
            - [-π, π) when centering=True
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # torch.rand_like on a complex tensor samples *complex* numbers (independent real
        # and imaginary parts). exp(1j * phase) would then also rescale the amplitude by
        # exp(-imag(phase)). Sample a real phase with the matching real precision instead.
        real_template = x.real if torch.is_complex(x) else x
        phase = torch.rand_like(real_template) * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)
class ToReal:
"""Extracts the real part of a complex-valued input tensor.
The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers
and returns only their real parts. If the input is already real-valued, it remains unchanged.
Returns:
np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
the real components of each element.
Example:
>>> to_real = ToReal()
>>> output = to_real(complex_tensor)
"""
def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
return x.real
class ToImaginary:
"""Extracts the imaginary part of a complex-valued input tensor.
The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers
and returns only their imaginary parts. If the input is already real-valued, it remains unchanged.
Returns:
np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
the imaginary components of each element.
Example:
>>> to_imaginary = ToImaginary()
>>> output = to_imaginary(complex_tensor)
"""
def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
return x.imag
class FFT2(BaseTransform):
    """Apply a centered 2D Fast Fourier Transform.

    The transform applies FFT2 over ``axis`` and shifts the zero-frequency components to
    the center. The work is delegated to ``F.applyfft2_np`` for numpy input and to
    ``F.applyfft2_torch`` for torch input.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray | torch.Tensor: Fourier transform of the input, with the
            zero-frequency components centered and the same shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        spectrum = F.applyfft2_np(x, axis=self.axis)
        return spectrum

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        spectrum = F.applyfft2_torch(x, dim=self.axis)
        return spectrum
class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform to centered spectra.

    The transform applies an inverse FFT shift followed by IFFT2 over ``axis``. The work
    is delegated to ``F.applyifft2_np`` for numpy input and to ``F.applyifft2_torch`` for
    torch input.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the inverse FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray | torch.Tensor: Inverse Fourier transform of the input, with the
            same shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        signal = F.applyifft2_np(x, axis=self.axis)
        return signal

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        signal = F.applyifft2_torch(x, dim=self.axis)
        return signal
class PadIfNeeded(BaseTransform):
    """Pad an image whose height or width is below a required minimum.

    Padding is delegated to ``F.padifneeded`` for both backends. It follows
    ``border_mode`` and is added symmetrically where possible. When the required extra
    size is odd, the right/bottom sides receive one more row/column of padding than the
    left/top sides.

    Args:
        min_height (int): Minimum height of the output.
        min_width (int): Minimum width of the output.
        border_mode (str): Padding mode ('constant', 'reflect', ...). Default is 'constant'.
        pad_value (float): Fill value used by constant padding. Default is 0.

    Returns:
        np.ndarray | torch.Tensor: Image of at least ``min_height`` x ``min_width``;
            the original image when no padding is required.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # Padded only if smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0
    ) -> None:
        self.min_height = min_height
        self.min_width = min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        # Shared by both backends: F.padifneeded accepts numpy arrays and torch tensors.
        return F.padifneeded(
            x, self.min_height, self.min_width, self.border_mode, self.pad_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._pad(x)
class CenterCrop(BaseTransform):
    """Extract a centered ``height`` x ``width`` window from the input.

    Cropping is delegated to ``F.center_crop`` for both numpy arrays and torch tensors,
    and is applied identically to every channel.

    Args:
        height (int): Height of the cropped output.
        width (int): Width of the cropped output.

    Returns:
        np.ndarray | torch.Tensor: Cropped data of shape (C, height, width), where C is
            the number of channels.

    Examples:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # Center crop to 224x224

    Notes:
        - An input smaller than the crop size is returned unchanged (documented behaviour
          of ``F.center_crop``).
    """

    def __init__(self, height: int, width: int) -> None:
        self.height = height
        self.width = width

    def _crop(self, x):
        # Same helper for both backends.
        return F.center_crop(x, self.height, self.width)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._crop(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._crop(x)
class FFTResize(BaseTransform):
    """Resizes an input image in the spectral domain with Fourier transforms.

    The input (CHW) is transformed with a centered 2D FFT over ``axis``. The spectrum is
    then zero-padded and/or center-cropped to the target size, and an inverse FFT brings
    it back to the spatial domain. Optionally, the output amplitudes are rescaled to
    account for the change in size.

    Args:
        size (Tuple[int, int]): Target dimensions (height, width).
        axis (Tuple[int, ...], optional): Axes over which the FFT/IFFT are applied.
            Default is (-2, -1), i.e. the Height and Width axes of a CHW input.
        scale (bool, optional): If True, multiply the output by
            ``target_size / original_size``. Here ``original_size`` is the number of
            elements spanned by ``axis`` before resizing and ``target_size`` is
            ``height * width``. This presumably compensates the FFT normalization change
            between input and output sizes; confirm against the ``F.applyfft2_*`` norm
            convention. Default is False.
        dtype (str | torch.dtype | numpy.dtype | None, optional): Complex output dtype.
            Accepts 'complex64' / 'complex128' or the equivalent torch / numpy dtype
            objects. ``None`` or a non-complex dtype falls back to 'complex64'.
            Default is 'complex64'.

    Returns:
        numpy.ndarray | torch.Tensor: Resized complex-valued data of shape
            (C, height, width), cast to ``dtype`` (also when ``scale`` is True).

    Examples:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # Resize to 128x128 using FFT

    Notes:
        - Input must be a multi-dimensional array/tensor of shape Channel x Height x Width.
        - Spectral-domain resizing preserves frequency characteristics better than spatial
          interpolation.
        - Operates on complex-valued data, preserving phase information.
        - The output is complex-valued due to the FFT round trip. For real-valued data,
          consider applying ToReal afterwards.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[Union[str, torch.dtype, np.dtype]] = "complex64"
    ) -> None:
        dtype = self._normalize_dtype(dtype)
        # Only complex outputs make sense after an FFT round trip.
        if dtype is None or "complex" not in dtype:
            dtype = "complex64"
        super().__init__(dtype)
        assert isinstance(size, tuple), "size must be a tuple"
        assert isinstance(axis, tuple), "axis must be a tuple"
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    @staticmethod
    def _normalize_dtype(dtype) -> Optional[str]:
        """Return ``dtype`` as a dtype-name string (e.g. 'complex64'), or None."""
        if dtype is None or isinstance(dtype, str):
            return dtype
        if isinstance(dtype, torch.dtype):
            # str(torch.complex64) == "torch.complex64"
            return str(dtype).removeprefix("torch.")
        try:
            return np.dtype(dtype).name
        except TypeError:
            # Not a recognisable dtype: fall back to its string form.
            return str(dtype)

    def _spatial_size(self, x) -> int:
        """Number of elements spanned by the FFT axes of ``x``."""
        return int(np.prod([x.shape[ax] for ax in self.axis]))

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        original_size = self._spatial_size(x)
        target_size = self.height * self.width
        x = F.applyfft2_np(x, axis=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)
        if self.scale:
            x = x * target_size / original_size
        # Cast on every path so the output dtype does not depend on `scale`.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        original_size = self._spatial_size(x)
        target_size = self.height * self.width
        x = F.applyfft2_torch(x, dim=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)
        if self.scale:
            x = x * target_size / original_size
        # Cast on every path so the output dtype does not depend on `scale`.
        return x.to(self.torch_dtype)
class SpatialResize:
    """Resize complex data to a given size in the image (spatial) domain.

    The real and imaginary parts are resized independently with PIL's ``Image.resize``
    and recombined into a complex result. The default PIL resampling filter is used,
    which recent Pillow versions document as bicubic for these non-palette images.

    Arguments:
        size: Target ``(height, width)`` of the resized output. A single ``int`` is
            interpreted as a square ``(size, size)`` target.

    Input is either 2D (H, W) or 3D with the first dimension treated as channels; each
    channel is resized separately. Torch tensors are resized on the CPU and returned on
    their original device.
    """

    def __init__(self, size):
        # Allow a single int for square outputs; __call__ indexes size[0] and size[1].
        if isinstance(size, int):
            size = (size, size)
        self.size = size

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        device = None
        if isinstance(array, torch.Tensor):
            device = array.device
            # Tensor.numpy() refuses CUDA tensors and tensors requiring grad; the PIL
            # path is not differentiable anyway, so detach and move to CPU first.
            array = array.detach().cpu().numpy()

        real_part = array.real
        imaginary_part = array.imag

        def zoom(channel: np.ndarray) -> np.ndarray:
            # PIL expects (width, height) while self.size is (height, width).
            image = Image.fromarray(channel)
            image = image.resize((self.size[1], self.size[0]))
            return np.array(image)

        if array.ndim == 2:
            # Single 2D image.
            resized_real = zoom(real_part)
            resized_imaginary = zoom(imaginary_part)
        else:
            # Assume the first dimension indexes channels; resize each one.
            resized_real = np.stack([zoom(real) for real in real_part])
            resized_imaginary = np.stack([zoom(imag) for imag in imaginary_part])

        resized_array = resized_real + 1j * resized_imaginary

        # Convert back to a torch tensor on the caller's device if needed.
        if device is not None:
            resized_array = torch.as_tensor(resized_array, device=device)
        return resized_array
class PolSAR(BaseTransform):
    """Convert Polarimetric SAR (PolSAR) data between channel layouts.

    Works with NumPy arrays and PyTorch tensors in Channels x Height x Width (CHW) layout
    and produces 1, 2, 3 or 4 output channels. A dictionary input such as
    ``{'HH': data1, 'VV': data2}`` is first converted to a CHW array with
    ``F.polsar_dict_to_array``, which stacks all values along axis 0.

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3, or 4).

    Supported conversions:
        - 1 channel -> 1 channel: identity
        - 2 channels -> 1 channel (first channel) or 2 channels (identity)
        - 4 channels [HH, HV, VH, VV] -> 1, 2, 3, or 4 channels:
            - 1 channel: first channel (HH) only
            - 2 channels: [HH, VV]
            - 3 channels: [HH, (HV + VH) / 2, VV]
            - 4 channels: all channels [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested channel conversion is invalid or not supported.

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For 4-channel input [HH, HV, VH, VV]
        >>> output = transform(input_data)  # Returns [HH, (HV+VH)/2, VV]

    Note:
        - Input data should be in Channels x Height x Width (CHW) format.
        - With out_channel=1, PolSAR always returns the first (HH) channel.
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        """1-channel input: only the identity conversion is valid (None otherwise)."""
        if out_channels == 1:
            return x
        return None

    def _handle_two_channels(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        """2-channel input: keep both channels or only the first one (None otherwise)."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return x
        return None

    def _handle_four_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """4-channel [HH, HV, VH, VV] input: select or combine channels (None if unsupported)."""
        # Thunks so only the requested combination is actually computed.
        selectors = {
            1: lambda: x[0:1],  # HH only
            2: lambda: backend.stack((x[0], x[3])),  # [HH, VV]
            3: lambda: backend.stack((x[0], 0.5 * (x[1] + x[2]), x[3])),  # [HH, (HV+VH)/2, VV]
            4: lambda: x,  # identity
        }
        if out_channels not in selectors:
            return None
        return selectors[out_channels]()

    def _convert_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Dispatch on the input channel count; raise if the conversion is unsupported."""
        in_channels = x.shape[0]
        if in_channels == 1:
            converted = self._handle_single_channel(x, out_channels)
        elif in_channels == 2:
            converted = self._handle_two_channels(x, out_channels)
        elif in_channels == 4:
            converted = self._handle_four_channels(x, out_channels, backend)
        else:
            converted = None
        if converted is None:
            raise ValueError(f"Invalid conversion: {x.shape[0]} -> {out_channels} channels")
        return converted

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]) -> np.ndarray | torch.Tensor:
        # Dict inputs are turned into a CHW array before the usual dispatch.
        stacked = F.polsar_dict_to_array(x)
        return super().__call__(stacked)
class Unsqueeze(BaseTransform):
    """Insert a singleton axis at position ``dim``.

    The output has one more dimension than the input, with size 1 at ``dim``.

    Args:
        dim (int): Index at which the new axis is inserted.

    Returns:
        np.ndarray | torch.Tensor: The input with a new size-1 dimension at ``dim``.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(2, 3, 4)  # Shape (2, 3, 4)
        >>> y = transform(x)  # Shape (1, 2, 3, 4)

    Note:
        Inputs first go through ``BaseTransform.__call__``, which documents that 2D
        (H, W) inputs are promoted to (1, H, W) before this transform runs.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        expanded = np.expand_dims(x, self.dim)
        return expanded

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        expanded = torch.unsqueeze(x, self.dim)
        return expanded
class ToTensor(BaseTransform):
    """Convert a numpy array or a torch tensor into a torch tensor of a given dtype.

    Args:
        dtype (str): Target dtype name. ``BaseTransform`` restricts it to one of
            'float32', 'float64', 'complex64', 'complex128'.

    Returns:
        torch.Tensor: The converted tensor with the requested dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.ones((1, 2, 3))
        >>> x_tensor = transform(x_numpy)  # torch.float32 tensor
        >>> x_existing = torch.ones((1, 2, 3), dtype=torch.int32)
        >>> x_converted = transform(x_existing)  # torch.float32 tensor
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        # torch.as_tensor avoids a copy when it can, so the result may share memory
        # with the input array.
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self.torch_dtype)