Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/transforms.py: 71%
244 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-21 08:33 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-21 08:33 +0000
1# MIT License
3# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from abc import ABC, abstractmethod
25from typing import Tuple, Union, Optional, Dict
26from types import ModuleType
28# External imports
29import torch
30import numpy as np
31from PIL import Image
33# Internal imports
34import torchcvnn.transforms.functional as F
37class BaseTransform(ABC):
38 """Abstract base class for transforms that can handle both numpy arrays and PyTorch tensors.
39 This class serves as a template for implementing transforms that can be applied to both numpy arrays
40 and PyTorch tensors while maintaining consistent behavior.
41 Inputs must be in CHW (Channel, Height, Width) format. If inputs have only 2 dimensions (Height, Width),
42 they will be converted to (1, Height, Width).
44 Args:
45 dtype (str, optional): Data type to convert inputs to. Must be one of:
46 'float32', 'float64', 'complex64', 'complex128'. If None, no type conversion is performed.
47 Default: None.
49 Raises:
50 AssertionError: If dtype is not a string or not one of the allowed types.
51 ValueError: If input is neither a numpy array nor a PyTorch tensor.
53 Methods:
54 __call__(x): Apply the transform to the input array/tensor.
55 __call_numpy__(x): Abstract method to implement numpy array transform.
56 __call_torch__(x): Abstract method to implement PyTorch tensor transform.
58 Example:
59 >>> class MyTransform(BaseTransform):
60 >>> def __call_numpy__(self, x):
61 >>> # Implement numpy transform
62 >>> pass
63 >>> def __call_torch__(self, x):
64 >>> # Implement torch transform
65 >>> pass
66 >>> transform = MyTransform(dtype='float32')
67 >>> output = transform(input_data) # Works with both numpy arrays and torch tensors
68 """
70 def __init__(self, dtype: str = None) -> None:
71 if dtype is not None:
72 assert isinstance(dtype, str), "dtype should be a string"
73 assert dtype in [
74 "float32",
75 "float64",
76 "complex64",
77 "complex128",
78 ], "dtype should be one of float32, float64, complex64, complex128"
79 self.np_dtype = getattr(np, dtype)
80 self.torch_dtype = getattr(torch, dtype)
82 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
83 """Apply transform to input."""
84 x = F.check_input(x)
85 if isinstance(x, np.ndarray):
86 return self.__call_numpy__(x)
87 elif isinstance(x, torch.Tensor):
88 return self.__call_torch__(x)
90 @abstractmethod
91 def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
92 """Apply transform to numpy array."""
93 raise NotImplementedError
95 @abstractmethod
96 def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
97 """Apply transform to torch tensor."""
98 raise NotImplementedError
class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex values, optionally keeping the phase.

    The work is delegated to ``F.log_normalize_amplitude``. As described by the
    original contract, the amplitude is clipped to ``[min_value, max_value]``,
    log10-transformed and normalized to the ``[0, 1]`` range:

        1. Extract amplitude and phase from the complex input
        2. Clip the amplitude between min_value and max_value
        3. Apply a log10 transform and normalize to [0, 1]
        4. Optionally recombine with the original phase

    Args:
        min_value (int | float, optional): Lower clipping bound of the amplitude.
            Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound of the amplitude.
            Defaults to 40.
        keep_phase (bool, optional): If True, the output is complex with the
            transformed amplitude and the original phase; if False, only the
            transformed amplitude is returned. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Transformed data with the same shape as the input.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # amplitudes mapped to log scale [0, 1]

    Note:
        The same functional routine serves both backends; the backend module
        (``np`` or ``torch``) is handed over as an argument.
    """

    def __init__(
        self, min_value: float = 0.02, max_value: float = 40, keep_phase: bool = True
    ) -> None:
        self.keep_phase = keep_phase
        self.min_value = min_value
        self.max_value = max_value

    def _log_normalize(self, x, backend: ModuleType):
        """Run the functional implementation with the given backend module."""
        return F.log_normalize_amplitude(
            x, backend, self.keep_phase, self.min_value, self.max_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._log_normalize(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._log_normalize(x, torch)
class Normalize(BaseTransform):
    """Per-channel 2x2 whitening of the [Re, Im] components of complex data.

    Each channel is centered with its mean and multiplied by the inverse square
    root of its 2x2 covariance matrix. The whitening matrices are computed once
    at construction time. Both numpy arrays and PyTorch tensors are supported,
    and in both cases the output keeps the complex dtype of the input.

    Args:
        means (array-like): Per-channel means for centering. Shape (C, 2) where
            C is the number of channels and the last axis is (real, imaginary).
        covs (array-like): Per-channel 2x2 covariance matrices. Shape (C, 2, 2).
            They are symmetrized before the eigen-decomposition and are expected
            to be symmetric positive definite.
        eps (float, optional): Lower bound applied to the eigenvalues before
            taking the inverse square root. Default: 1e-12.

    Returns:
        np.ndarray | torch.Tensor: Normalized complex data with the same shape
        and dtype as the input.

    Raises:
        AssertionError: If ``means`` is not (C, 2) or ``covs`` is not (C, 2, 2).
        ValueError: If the input is not a complex CHW array/tensor or if its
            number of channels differs from the number of means.

    Example:
        >>> means = [[0, 0], [1, 1]]
        >>> covs = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
        >>> transform = Normalize(means, covs)
        >>> output = transform(input_data)  # Normalizes each channel independently
    """

    def __init__(self, means, covs, eps=1e-12):
        super().__init__()
        # means: (C, 2) ; covs: (C, 2, 2)
        self.means = np.asarray(means, dtype=np.float64)
        self.covs = np.asarray(covs, dtype=np.float64)
        # Explicit raises keep the historical AssertionError even under -O.
        if not (self.means.ndim == 2 and self.means.shape[1] == 2):
            raise AssertionError("means must have shape (C, 2)")
        if not (self.covs.ndim == 3 and self.covs.shape[1:] == (2, 2)):
            raise AssertionError("covs must have shape (C, 2, 2)")
        self.num_channels = self.means.shape[0]

        # Precompute the per-channel whitening matrices W = Sigma^{-1/2}.
        # Symmetrizing first protects eigh against slightly asymmetric inputs.
        covs_sym = 0.5 * (self.covs + np.swapaxes(self.covs, -1, -2))
        eigvals, eigvecs = np.linalg.eigh(covs_sym)
        # Clamp the eigenvalues so that the inverse square root stays finite.
        inv_sqrt = 1.0 / np.sqrt(np.maximum(eigvals, eps))
        # V diag(d) V^T, with the diagonal product written as a column scaling.
        self.W_np = (eigvecs * inv_sqrt[:, None, :]) @ np.swapaxes(eigvecs, -1, -2)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        if x.ndim != 3 or not np.iscomplexobj(x):
            raise ValueError("Expect complex CHW ndarray.")
        C, H, W = x.shape
        if C != self.num_channels:
            raise ValueError("Channel mismatch.")

        # Stack real and imaginary parts: (C, 2, H, W), then center.
        Z = np.stack([x.real, x.imag], axis=1)
        Zc = Z - self.means[:, :, None, None]

        # Apply the 2x2 whitening matrix of each channel to the (Re, Im) axis.
        Y = np.einsum("cij,cjhw->cihw", self.W_np, Zc)

        # The float64 parameters promote the computation to complex128; cast
        # back so that the output dtype follows the input, as the torch path does.
        return (Y[:, 0] + 1j * Y[:, 1]).astype(x.dtype, copy=False)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3 or not x.is_complex():
            raise ValueError("Expect complex CHW tensor.")
        C, H, W = x.shape
        if C != self.num_channels:
            raise ValueError("Channel mismatch.")

        # Move the precomputed parameters to the device / real dtype of the input.
        Wc = torch.as_tensor(self.W_np, device=x.device, dtype=x.real.dtype)
        muc = torch.as_tensor(self.means, device=x.device, dtype=x.real.dtype)

        # Stack real and imaginary parts: (C, 2, H, W), then center.
        Z = torch.stack([x.real, x.imag], dim=1)
        Zc = Z - muc.view(C, 2, 1, 1)

        # Apply the 2x2 whitening matrix of each channel to the (Re, Im) axis.
        Y = torch.einsum("cij,cjhw->cihw", Wc, Zc)

        return torch.complex(Y[:, 0], Y[:, 1])
class Amplitude(BaseTransform):
    """Map complex-valued data to its amplitude (modulus).

    The absolute value of the input is computed element-wise and the result is
    cast to the requested dtype.

    Args:
        dtype (str): Data type of the output ('float32', 'float64', etc).

    Returns:
        np.ndarray | torch.Tensor: Amplitudes of the input, with the same shape
        as the input and the type given by ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        modulus = torch.abs(x)
        return modulus.to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        modulus = np.abs(x)
        return modulus.astype(self.np_dtype)
class RealImaginary(BaseTransform):
    """Split complex-valued data into real and imaginary channels.

    The real and imaginary parts are stacked along a new leading axis which is
    then merged with the channel axis, so the output holds all the real
    channels followed by all the imaginary channels.

    Returns:
        np.ndarray | torch.Tensor: Real-valued data with shape (2*C, H, W),
        where C is the number of channels of the CHW input.
    """

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # (C, H, W) -> (2, C, H, W)
        parts = torch.stack((x.real, x.imag), dim=0)
        # (2, C, H, W) -> (2*C, H, W)
        return torch.flatten(parts, start_dim=0, end_dim=1)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        # (C, H, W) -> (2, C, H, W)
        parts = np.stack((x.real, x.imag), axis=0)
        # (2, C, H, W) -> (2*C, H, W)
        return parts.reshape((-1,) + tuple(parts.shape[2:]))
class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Every element of the input is multiplied by ``exp(j * phi)`` where ``phi``
    is drawn independently and uniformly in [0, 2π), or in [-π, π) when
    centering is enabled. The phases are real-valued, so the modulus of the
    input is left untouched.

    Args:
        dtype (str): Data type of the output. Must be one of the supported
            complex dtypes.
        centering (bool, optional): If True, π is subtracted from the sampled
            phases so that their distribution is centered around 0.
            Default is False.

    Returns:
        torch.Tensor or numpy.ndarray: Phase-shifted complex-valued data with
        the same shape as the input, cast to ``dtype``.

    Examples:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # Applies random phase shifts

    Notes:
        - Input data is expected to be complex-valued
        - Phase shifts are uniformly distributed in:
            - [0, 2π) when centering=False
            - [-π, π) when centering=True
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # Sample with the *real* dtype of x: sampling with the complex dtype of
        # x would give a complex "phase" whose imaginary part rescales the
        # amplitude through exp(1j * phase).
        phase = torch.rand(x.shape, dtype=x.real.dtype, device=x.device)
        phase = phase * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)
313class ToReal:
314 """Extracts the real part of a complex-valued input tensor.
316 The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers
317 and returns only their real parts. If the input is already real-valued, it remains unchanged.
319 Returns:
320 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
321 the real components of each element.
323 Example:
324 >>> to_real = ToReal()
325 >>> output = to_real(complex_tensor)
326 """
328 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
329 return x.real
332class ToImaginary:
333 """Extracts the imaginary part of a complex-valued input tensor.
335 The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers
336 and returns only their imaginary parts. If the input is already real-valued, it remains unchanged.
338 Returns:
339 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
340 the imaginary components of each element.
342 Example:
343 >>> to_imaginary = ToImaginary()
344 >>> output = to_imaginary(complex_tensor)
345 """
347 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
348 return x.imag
class FFT2(BaseTransform):
    """Apply a 2D Fast Fourier Transform (FFT) to the input.

    The computation is delegated to ``F.applyfft2_np`` / ``F.applyfft2_torch``
    which, per the original contract, compute the FFT along the requested axes
    and shift the zero-frequency component to the center.

    Args:
        axis (Tuple[int, ...], optional): The axes over which to compute the
            FFT. Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The 2D Fourier transform of the input,
        with the same shape as the input.

    Notes:
        - The transform is applied along the specified dimensions (`axis`).
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        spectrum = F.applyfft2_np(x, axis=self.axis)
        return spectrum

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        spectrum = F.applyfft2_torch(x, dim=self.axis)
        return spectrum
class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform (IFFT) to the input.

    The computation is delegated to ``F.applyifft2_np`` / ``F.applyifft2_torch``
    which, per the original contract, undo the frequency shift before computing
    the inverse FFT along the requested axes.

    Args:
        axis (Tuple[int, ...], optional): The axes over which to compute the
            inverse FFT. Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The inverse Fourier transform of the
        input, with the same shape as the input.

    Notes:
        - The transform is applied along the specified dimensions (`axis`).
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        signal = F.applyifft2_np(x, axis=self.axis)
        return signal

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        signal = F.applyifft2_torch(x, dim=self.axis)
        return signal
class PadIfNeeded(BaseTransform):
    """Pad an image whose dimensions are below the requested minimum.

    The padding itself is carried out by ``F.padifneeded``, which receives the
    minimum height and width, the border mode and the pad value. Per the
    original contract, padding is added symmetrically on both sides when
    possible; when the amount to add is odd, the right and bottom sides
    receive one more element than the left and top sides.

    Args:
        min_height (int): Minimum height requirement for the image.
        min_width (int): Minimum width requirement for the image.
        border_mode (str): Type of padding to apply ('constant', 'reflect', etc.).
            Default is 'constant'.
        pad_value (float): Value for constant padding (if applicable). Default is 0.

    Returns:
        np.ndarray | torch.Tensor: Padded image with dimensions at least
        min_height x min_width, or the original image if no padding is required.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # Pads if smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0,
    ) -> None:
        self.min_height, self.min_width = min_height, min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        """Forward the input and the stored settings to the functional padding."""
        return F.padifneeded(
            x, self.min_height, self.min_width, self.border_mode, self.pad_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._pad(x)
class CenterCrop(BaseTransform):
    """Center crop an input array/tensor to the specified size.

    A centered rectangular region with the requested height and width is
    extracted by ``F.center_crop``, which serves both the numpy and the torch
    backends.

    Args:
        height (int): Target height of the cropped output.
        width (int): Target width of the cropped output.

    Returns:
        np.ndarray | torch.Tensor: Center cropped array/tensor with shape
        (C, height, width), where C is the number of channels.

    Examples:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # Center crops to 224x224

    Notes:
        - Per the original contract, an input smaller than the crop size is
          returned unchanged
        - The crop is applied identically to all channels
    """

    def __init__(self, height: int, width: int) -> None:
        self.height, self.width = height, width

    def _crop(self, x):
        """Forward the input and the target size to the functional crop."""
        return F.center_crop(x, self.height, self.width)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._crop(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._crop(x)
class FFTResize(BaseTransform):
    """Resize an input image in the spectral domain with Fourier transforms.

    A 2D FFT is applied to the CHW input along the given axes, the spectrum is
    padded and/or center cropped to the target size, and an inverse FFT brings
    the data back to the spatial domain. Optionally, the output amplitudes are
    rescaled by ``(target height * target width) / (input height * input width)``.
    The output is always cast to the configured complex dtype, whether or not
    the scaling is enabled.

    Args:
        size (Tuple[int, int]): Target dimensions (height, width) for resizing.
        axis (Tuple[int, ...], optional): The axes over which to apply the FFT.
            Default is (-2, -1) which, for an array / tensor of shape CHW,
            corresponds to the Height and Width axes.
        scale (bool, optional): If True, scales the output amplitudes by the
            ratio of the target size to the input size. Default is False.
        dtype (str, optional): Output data type, 'complex64' or 'complex128'.
            None, or any value that does not contain "complex", falls back to
            'complex64'. Default is 'complex64'.

    Returns:
        numpy.ndarray or torch.Tensor: Resized image as a complex-valued
        array/tensor of shape (C, height, width).

    Raises:
        AssertionError: If ``size`` or ``axis`` is not a tuple.

    Examples:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # Resize to 128x128 using FFT

    Notes:
        - Input must be an array/tensor of shape Channel x Height x Width; the
          input size used for scaling is read from axes 1 and 2.
        - The output is complex-valued due to the nature of FFT operations. If
          you are working with real-valued data, it is recommended to call
          ToReal after applying this transform.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[str] = "complex64",
    ) -> None:
        # The result of the inverse FFT is complex, so only complex dtypes make sense.
        if dtype is None or "complex" not in str(dtype):
            dtype = "complex64"

        super().__init__(dtype)
        # Explicit raises keep the historical AssertionError even under -O.
        if not isinstance(size, tuple):
            raise AssertionError("size must be a tuple")
        if not isinstance(axis, tuple):
            raise AssertionError("axis must be a tuple")
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_np(x, axis=self.axis)
        # Pad when upsampling, crop when downsampling; each call is expected to
        # leave the data untouched when it does not apply.
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path so the output dtype does not depend on `scale`.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_torch(x, dim=self.axis)
        # Pad when upsampling, crop when downsampling; each call is expected to
        # leave the data untouched when it does not apply.
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path so the output dtype does not depend on `scale`.
        return x.to(self.torch_dtype)
class SpatialResize:
    """Resize a complex array/tensor to a given size in image space.

    The real and imaginary parts are resized separately, channel by channel,
    through ``PIL.Image.resize`` and then recombined into a complex result.
    The interpolation is the default resampling filter of the installed Pillow
    version (no filter is passed explicitly); the original documentation
    describes it as bicubic.

    Torch tensors are detached and moved to the CPU for the resize, and the
    result is sent back to the device of the input tensor.

    Arguments:
        size: The target size (height, width) of the resized data.

    Returns:
        np.ndarray | torch.Tensor: Complex data of shape (height, width) for a
        2D input, or (C, height, width) otherwise, of the same kind (numpy or
        torch) as the input.
    """

    def __init__(self, size):
        self.size = size

    def _zoom(self, plane: np.ndarray) -> np.ndarray:
        """Resize a single real-valued 2D plane with PIL."""
        image = Image.fromarray(plane)
        # PIL expects (width, height) whereas self.size is (height, width).
        image = image.resize((self.size[1], self.size[0]))
        return np.array(image)

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        device = None
        if isinstance(array, torch.Tensor):
            device = array.device
            # detach() and cpu() make the conversion work for tensors that
            # require grad or live on an accelerator.
            array = array.detach().cpu().numpy()

        real_part = array.real
        imaginary_part = array.imag

        if len(array.shape) == 2:
            # A single two dimensional plane
            resized_real = self._zoom(real_part)
            resized_imaginary = self._zoom(imaginary_part)
        else:
            # The first dimension is assumed to be the channel: every channel
            # is resized independently.
            resized_real = np.stack([self._zoom(plane) for plane in real_part])
            resized_imaginary = np.stack(
                [self._zoom(plane) for plane in imaginary_part]
            )

        resized_array = resized_real + 1j * resized_imaginary

        # Convert back to a torch tensor, on the original device, if necessary
        if device is not None:
            resized_array = torch.as_tensor(resized_array, device=device)

        return resized_array
class PolSAR(BaseTransform):
    """Channel conversions for Polarimetric Synthetic Aperture Radar (PolSAR) data.

    The transform maps an input with 1, 2, 3 or 4 channels to the requested
    number of output channels, for both NumPy arrays and PyTorch tensors.
    The input first goes through ``F.polsar_dict_to_array``; per the original
    contract, a dictionary such as {'HH': data1, 'VV': data2} is stacked along
    axis 0 to form a CHW array.

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3, or 4)

    Supported conversions:
        - 1 channel -> 1 channel: Identity
        - 2 channels -> 1 or 2 channels
        - 3 channels -> 1, 2, or 3 channels where:
            - 1 channel: Returns first channel only
            - 2 channels: Returns [HH, VV] channels
            - 3 channels: Returns all channels [HH, HV, VV]
        - 4 channels -> 1, 2, 3, or 4 channels where:
            - 1 channel: Returns first channel only
            - 2 channels: Returns [HH, VV] channels
            - 3 channels: Returns [HH, (HV+VH)/2, VV]
            - 4 channels: Returns all channels [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested channel conversion is invalid or not supported

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For 4-channel input [HH, HV, VH, VV]
        >>> output = transform(input_data)  # Returns [HH, (HV+VH)/2, VV]

    Note:
        - Input data should have format Channels x Height x Width (CHW).
        - With out_channel equal to 1, the first channel is returned, kept as
          a leading axis of size 1.
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        """Convert a 1-channel input; None flags an unsupported request."""
        if out_channels == 1:
            return x
        return None

    def _handle_two_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        """Convert a 2-channel input; None flags an unsupported request."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return x
        return None

    def _handle_three_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Convert a 3-channel input; None flags an unsupported request."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            # First and last channels
            return backend.stack((x[0], x[2]))
        if out_channels == 3:
            return backend.stack((x[0], x[1], x[2]))
        return None

    def _handle_four_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Convert a 4-channel input; None flags an unsupported request."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            # First and last channels
            return backend.stack((x[0], x[3]))
        if out_channels == 3:
            # The two middle channels are averaged into a single one
            return backend.stack((x[0], 0.5 * (x[1] + x[2]), x[3]))
        if out_channels == 4:
            return x
        return None

    def _convert_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Dispatch on the number of input channels and reject invalid requests."""
        in_channels = x.shape[0]
        if in_channels == 1:
            converted = self._handle_single_channel(x, out_channels)
        elif in_channels == 2:
            converted = self._handle_two_channels(x, out_channels)
        elif in_channels == 3:
            converted = self._handle_three_channels(x, out_channels, backend)
        elif in_channels == 4:
            converted = self._handle_four_channels(x, out_channels, backend)
        else:
            converted = None

        if converted is None:
            raise ValueError(
                f"Invalid conversion: {x.shape[0]} -> {out_channels} channels"
            )
        return converted

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(
        self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]
    ) -> np.ndarray | torch.Tensor:
        # Turn a possible dictionary input into an array before the usual dispatch
        stacked = F.polsar_dict_to_array(x)
        return super().__call__(stacked)
class Unsqueeze(BaseTransform):
    """Insert a singleton dimension into the input array/tensor.

    A new axis of size 1 is added at the requested position, which increases
    the number of dimensions of the input by one.

    Args:
        dim (int): Position where the new axis is inserted.

    Returns:
        np.ndarray | torch.Tensor: The input with an extra axis of size 1 at
        position ``dim``.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(3, 4)  # Shape (3, 4)
        >>> y = transform(x)  # Shape (1, 3, 4)
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        expanded = np.expand_dims(x, axis=self.dim)
        return expanded

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        expanded = x.unsqueeze(dim=self.dim)
        return expanded
class ToTensor(BaseTransform):
    """Convert a numpy array or a torch tensor to a torch tensor of a given dtype.

    Numpy arrays are converted with ``torch.as_tensor`` and existing tensors
    are cast with ``Tensor.to``; in both cases the result is a torch tensor.

    Args:
        dtype (str): Target data type of the output tensor. It is validated by
            the base class and must be one of 'float32', 'float64',
            'complex64' or 'complex128'.

    Returns:
        torch.Tensor: The converted tensor with the specified dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.array([1, 2, 3])
        >>> x_tensor = transform(x_numpy)  # torch tensor of dtype float32
        >>> x_existing = torch.tensor([1, 2, 3], dtype=torch.int32)
        >>> x_converted = transform(x_existing)  # torch tensor of dtype float32
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        # Unlike the other transforms, the numpy branch returns a torch tensor.
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self.torch_dtype)