Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/transforms.py: 62%
204 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
1# MIT License
3# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from abc import ABC, abstractmethod
25from typing import Tuple, Union, Optional, Dict
26from types import NoneType, ModuleType
28# External imports
29import torch
30import numpy as np
31from PIL import Image
33# Internal imports
34import torchcvnn.transforms.functional as F
class BaseTransform(ABC):
    """Abstract base class for transforms that handle both NumPy arrays and PyTorch tensors.

    Subclasses implement ``__call_numpy__`` and ``__call_torch__``. ``__call__``
    first passes the input through ``F.check_input`` and then dispatches on the
    resulting type, so a single transform object works with both backends.

    Inputs must be in CHW (Channel, Height, Width) format. Inputs with only two
    dimensions (Height, Width) are expected to be converted to (1, Height, Width);
    that normalization is delegated to ``F.check_input``.

    Args:
        dtype (str, optional): Data type to convert inputs to. Must be one of
            'float32', 'float64', 'complex64', 'complex128'. When given, the
            matching dtypes are stored in ``self.np_dtype`` and
            ``self.torch_dtype`` for subclasses to use. If None, neither
            attribute is set and no type conversion is configured.
            Default: None.

    Raises:
        AssertionError: If dtype is not a string or not one of the allowed types.
        ValueError: If the input is neither a NumPy array nor a PyTorch tensor.

    Methods:
        __call__(x): Apply the transform to the input array/tensor.
        __call_numpy__(x): Abstract method implementing the NumPy transform.
        __call_torch__(x): Abstract method implementing the PyTorch transform.

    Example:
        >>> class MyTransform(BaseTransform):
        ...     def __call_numpy__(self, x):
        ...         return x  # NumPy implementation goes here
        ...     def __call_torch__(self, x):
        ...         return x  # PyTorch implementation goes here
        >>> transform = MyTransform(dtype='float32')
        >>> output = transform(input_data)  # works with arrays and tensors
    """

    def __init__(self, dtype: str | NoneType = None) -> None:
        if dtype is not None:
            assert isinstance(dtype, str), "dtype should be a string"
            assert dtype in ["float32", "float64", "complex64", "complex128"], "dtype should be one of float32, float64, complex64, complex128"
            # The same dtype name resolves to the equivalent type in both libraries.
            self.np_dtype = getattr(np, dtype)
            self.torch_dtype = getattr(torch, dtype)

    def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        """Apply the transform to ``x``, dispatching on its type.

        Args:
            x: Input array or tensor (normalized by ``F.check_input`` first).

        Returns:
            The output of ``__call_numpy__`` or ``__call_torch__``.

        Raises:
            ValueError: If ``x`` is neither a NumPy array nor a PyTorch tensor.
        """
        x = F.check_input(x)
        if isinstance(x, np.ndarray):
            return self.__call_numpy__(x)
        if isinstance(x, torch.Tensor):
            return self.__call_torch__(x)
        # Previously this fell through and silently returned None; honor the
        # documented contract instead.
        raise ValueError(
            f"Input must be a numpy.ndarray or a torch.Tensor, got {type(x).__name__}"
        )

    @abstractmethod
    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Apply the transform to a NumPy array."""
        raise NotImplementedError

    @abstractmethod
    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transform to a PyTorch tensor."""
        raise NotImplementedError
class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex values, optionally keeping the phase.

    The work is delegated to ``F.log_normalize_amplitude`` with either the NumPy
    or the PyTorch module as backend. Per the transform's contract:

    1. Amplitude and phase are extracted from the complex input.
    2. The amplitude is clipped to ``[min_value, max_value]``.
    3. The clipped amplitude is log10-transformed and normalized to ``[0, 1]``.
    4. Optionally, the result is recombined with the original phase.

    Args:
        min_value (int | float, optional): Lower clipping bound for the
            amplitude. Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound for the
            amplitude. Defaults to 40.
        keep_phase (bool, optional): If True, return complex output with the
            transformed amplitude and the original phase. If False, return only
            the transformed (real) amplitude. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Transformed data with the same shape as the input.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # amplitudes mapped to log scale [0, 1]
    """

    def __init__(self, min_value: float = 0.02, max_value: float = 40, keep_phase: bool = True) -> None:
        self.min_value, self.max_value = min_value, max_value
        self.keep_phase = keep_phase

    def _log_normalize(self, x, backend):
        # Shared implementation; only the array library differs between backends.
        return F.log_normalize_amplitude(
            x, backend, self.keep_phase, self.min_value, self.max_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._log_normalize(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._log_normalize(x, torch)
class Amplitude(BaseTransform):
    """Replace each (complex) value by its magnitude.

    The absolute value is computed with the input's own library and then cast
    to the dtype given at construction.

    Args:
        dtype (str): Output dtype name, validated by ``BaseTransform``
            ('float32', 'float64', 'complex64' or 'complex128').

    Returns:
        np.ndarray | torch.Tensor: Magnitudes, same shape as the input, cast to ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        magnitude = x.abs()
        return magnitude.to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        magnitude = np.absolute(x)
        return magnitude.astype(self.np_dtype)
class RealImaginary(BaseTransform):
    """Split complex data into real and imaginary channels.

    The real and imaginary parts are stacked on a new leading axis and then
    merged with the channel axis. For a (C, H, W) input, the (2*C, H, W) output
    holds all C real channels first, followed by all C imaginary channels.

    Note:
        With PyTorch, ``.imag`` is only defined for complex tensors, so the
        input is expected to be complex.

    Returns:
        np.ndarray | torch.Tensor: Real-valued data of shape (2*C, H, W).
    """

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        parts = torch.stack((x.real, x.imag))  # (C, H, W) -> (2, C, H, W)
        return parts.flatten(start_dim=0, end_dim=1)  # -> (2*C, H, W)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        parts = np.stack((x.real, x.imag))  # (C, H, W) -> (2, C, H, W)
        return parts.reshape((-1,) + parts.shape[2:])  # -> (2*C, H, W)
class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Every element is multiplied by ``exp(j * phi)``, where ``phi`` is drawn
    independently per element, uniformly in ``[0, 2π)``. With centering
    enabled, it is drawn in ``[-π, π)``. Only the phase changes; amplitudes are
    preserved.

    Args:
        dtype (str): Output dtype; should be one of the complex dtypes accepted
            by ``BaseTransform`` ('complex64' or 'complex128').
        centering (bool, optional): If True, subtract π from the sampled phases
            so they are centered around 0. Default is False.

    Returns:
        torch.Tensor | numpy.ndarray: Phase-shifted data with the same shape as
        the input, cast to ``dtype``.

    Example:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(1, 3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # random per-element phase shifts
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # The phase must be *real*: torch.rand_like on a complex tensor samples
        # complex values, and exp(1j * (a + ib)) = exp(-b) * exp(1j * a) would
        # then also rescale the amplitudes.
        if x.is_complex():
            real_dtype = x.real.dtype
        elif x.is_floating_point():
            real_dtype = x.dtype
        else:
            # Integer/bool inputs: torch.rand only samples floating types.
            real_dtype = torch.get_default_dtype()
        phase = torch.rand(x.shape, dtype=real_dtype, device=x.device) * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        # np.random.rand always returns real float64 samples in [0, 1).
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)
226class ToReal:
227 """Extracts the real part of a complex-valued input tensor.
229 The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers
230 and returns only their real parts. If the input is already real-valued, it remains unchanged.
232 Returns:
233 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
234 the real components of each element.
236 Example:
237 >>> to_real = ToReal()
238 >>> output = to_real(complex_tensor)
239 """
240 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
241 return x.real
244class ToImaginary:
245 """Extracts the imaginary part of a complex-valued input tensor.
247 The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers
248 and returns only their imaginary parts. If the input is already real-valued, it remains unchanged.
250 Returns:
251 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only
252 the imaginary components of each element.
254 Example:
255 >>> to_imaginary = ToImaginary()
256 >>> output = to_imaginary(complex_tensor)
257 """
258 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
259 return x.imag
class FFT2(BaseTransform):
    """Apply a centered 2D Fast Fourier Transform.

    Delegates to ``F.applyfft2_np`` / ``F.applyfft2_torch``. These compute the
    2D FFT over the requested axes and shift the zero-frequency component to
    the center.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray | torch.Tensor: Centered spectrum with the same shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        axes = self.axis
        return F.applyfft2_np(x, axis=axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        dims = self.axis
        return F.applyfft2_torch(x, dim=dims)
class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform to a centered spectrum.

    Delegates to ``F.applyifft2_np`` / ``F.applyifft2_torch``. These undo the
    zero-frequency centering (inverse FFT shift) and then apply the 2D inverse
    FFT over the requested axes (not necessarily the last two).

    Args:
        axis (Tuple[int, ...], optional): Axes over which the inverse FFT is
            computed. Default is (-2, -1).

    Returns:
        numpy.ndarray | torch.Tensor: Inverse-transformed data with the same shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        axes = self.axis
        return F.applyifft2_np(x, axis=axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        dims = self.axis
        return F.applyifft2_torch(x, dim=dims)
class PadIfNeeded(BaseTransform):
    """Pad an image whose height or width is below a required minimum.

    Padding is delegated to ``F.padifneeded`` for both backends. Per the
    transform's contract, padding is split symmetrically between the two sides.
    When the amount to add is odd, the bottom/right side gets one extra element
    compared to the top/left side.

    Args:
        min_height (int): Minimum required image height.
        min_width (int): Minimum required image width.
        border_mode (str): Padding mode ('constant', 'reflect', ...). Default is 'constant'.
        pad_value (float): Fill value used for constant padding. Default is 0.

    Returns:
        np.ndarray | torch.Tensor: Image at least ``min_height`` x ``min_width``;
        the original image if no padding was required.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # padded only if smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0
    ) -> None:
        self.min_height, self.min_width = min_height, min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        # Same helper for both backends; arguments are passed positionally.
        return F.padifneeded(x, self.min_height, self.min_width, self.border_mode, self.pad_value)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._pad(x)
class CenterCrop(BaseTransform):
    """Extract a centered ``height`` x ``width`` window from the input.

    Cropping is delegated to ``F.center_crop`` for both NumPy and PyTorch, and
    is applied identically to every channel.

    Args:
        height (int): Target height of the crop.
        width (int): Target width of the crop.

    Returns:
        np.ndarray | torch.Tensor: Cropped data of shape (C, height, width).

    Example:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # centered 224x224 crop

    Note:
        Per the transform's contract, an input smaller than the crop size is
        returned unchanged.
    """

    def __init__(self, height: int, width: int) -> None:
        self.height, self.width = height, width

    def _crop(self, x):
        return F.center_crop(x, self.height, self.width)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._crop(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._crop(x)
class FFTResize(BaseTransform):
    """Resize an image in the spectral domain using Fourier transforms.

    The input (CHW) is transformed with a centered 2D FFT along ``axis``. The
    spectrum is then zero-padded and/or center-cropped to the target size and
    transformed back with the inverse FFT. Optionally, amplitudes are rescaled
    by ``target_pixels / original_pixels`` to stay consistent with the input
    size. The result is always cast to the configured complex dtype.

    Args:
        size (Tuple[int, int]): Target (height, width).
        axis (Tuple[int, ...], optional): Axes over which the FFT is applied.
            Default is (-2, -1), i.e. the Height and Width axes of a CHW input.
        scale (bool, optional): If True, multiply the output by
            ``target_pixels / original_pixels``. Default is False.
        dtype (str, optional): Output dtype name, 'complex64' or 'complex128'.
            None or any non-complex name falls back to 'complex64', since the
            inverse FFT produces complex values. Default is 'complex64'.

    Returns:
        numpy.ndarray | torch.Tensor: Complex-valued resized data of shape
        (C, height, width).

    Raises:
        AssertionError: If ``size`` or ``axis`` is not a tuple.

    Example:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # 128x128, resized via FFT

    Note:
        - The input is expected to be 3-D (Channel x Height x Width); the pixel
          count used for scaling is ``shape[1] * shape[2]``.
        - The output is complex-valued; for real-valued data, consider applying
          ``ToReal`` afterwards.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[str] = "complex64"
    ) -> None:
        # The inverse FFT output is complex, so only complex dtypes make sense.
        if dtype is None or "complex" not in str(dtype):
            dtype = "complex64"

        super().__init__(dtype)
        assert isinstance(size, tuple), "size must be a tuple"
        assert isinstance(axis, tuple), "axis must be a tuple"
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_np(x, axis=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast in both branches; previously scale=True skipped the cast.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_torch(x, dim=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast in both branches; previously scale=True skipped the cast.
        return x.to(self.torch_dtype)
class SpatialResize:
    """
    Resize complex data in image space with PIL.

    The real and imaginary parts are resized independently, using PIL's
    ``Image.resize`` with its default filter (documented by Pillow as bicubic
    for most modes), and then recombined. 2-D inputs are resized directly. For
    3-D inputs the first axis is treated as channels and each channel is
    resized separately.

    Arguments:
        size: Target (height, width) of the resized data.

    Notes:
        - PyTorch tensors are detached and copied to the CPU for the PIL round
          trip. The result is returned as a tensor on the input's device.
        - NOTE(review): PIL's float image mode is 32-bit, so float64 data
          presumably loses precision here; confirm if that matters.
    """

    def __init__(self, size):
        self.size = size

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:

        device = None
        if isinstance(array, torch.Tensor):
            device = array.device
            # Tensor.numpy() refuses tensors that require grad or live on a GPU.
            array = array.detach().cpu().numpy()

        real_part = array.real
        imaginary_part = array.imag

        def zoom(plane):
            # Round-trip a single 2-D plane through PIL; PIL takes (width, height).
            image = Image.fromarray(plane)
            image = image.resize((self.size[1], self.size[0]))
            return np.array(image)

        if len(array.shape) == 2:
            # A single two-dimensional plane.
            resized_real = zoom(real_part)
            resized_imaginary = zoom(imaginary_part)
        else:
            # Treat the first axis as channels and resize each one separately.
            resized_real = np.stack([zoom(plane) for plane in real_part])
            resized_imaginary = np.stack([zoom(plane) for plane in imaginary_part])

        resized_array = resized_real + 1j * resized_imaginary

        # Convert back to a tensor on the original device if needed.
        if device is not None:
            resized_array = torch.as_tensor(resized_array, device=device)

        return resized_array
class PolSAR(BaseTransform):
    """Convert Polarimetric SAR (PolSAR) data between channel layouts.

    Supports 1, 2, 3 or 4 output channels, for NumPy arrays and PyTorch
    tensors alike. A dictionary input such as ``{'HH': data1, 'VV': data2}`` is
    first turned into a CHW array by ``F.polsar_dict_to_array`` (its values are
    stacked along axis 0).

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3 or 4).

    Supported conversions (by number of input channels):
        - 1 -> 1: identity
        - 2 -> 1: first channel only; 2 -> 2: identity
        - 4 -> 1: first channel (HH) only
        - 4 -> 2: [HH, VV]
        - 4 -> 3: [HH, (HV + VH) / 2, VV]
        - 4 -> 4: identity [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested conversion is not supported.

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For a 4-channel input ordered [HH, HV, VH, VV]
        >>> output = transform(input_data)  # -> [HH, (HV + VH) / 2, VV]

    Note:
        - Input data is expected in Channels x Height x Width (CHW) format.
        - A single output channel is always the first input channel (HH).
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        # Only the identity mapping exists for one input channel.
        if out_channels != 1:
            return None
        return x

    def _handle_two_channels(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        if out_channels == 1:
            return x[0:1]  # slice keeps the channel axis
        if out_channels == 2:
            return x
        return None

    def _handle_four_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        # Channel order assumed: [HH, HV, VH, VV].
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return backend.stack((x[0], x[3]))
        if out_channels == 3:
            cross_pol = 0.5 * (x[1] + x[2])
            return backend.stack((x[0], cross_pol, x[3]))
        if out_channels == 4:
            return x
        return None

    def _convert_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        in_channels = x.shape[0]
        if in_channels == 1:
            converted = self._handle_single_channel(x, out_channels)
        elif in_channels == 2:
            converted = self._handle_two_channels(x, out_channels)
        elif in_channels == 4:
            converted = self._handle_four_channels(x, out_channels, backend)
        else:
            converted = None

        if converted is None:
            raise ValueError(f"Invalid conversion: {x.shape[0]} -> {out_channels} channels")
        return converted

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]) -> np.ndarray | torch.Tensor:
        # Dictionaries of polarizations are stacked before the generic dispatch.
        stacked = F.polsar_dict_to_array(x)
        return super().__call__(stacked)
class Unsqueeze(BaseTransform):
    """Insert a new axis of length one at position ``dim``.

    The output has one more dimension than the (``F.check_input``-normalized)
    input, with a 1 inserted at ``dim``.

    Args:
        dim (int): Index at which the new axis is inserted.

    Returns:
        np.ndarray | torch.Tensor: Input with an extra singleton dimension.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(2, 3, 4)  # Shape (2, 3, 4)
        >>> y = transform(x)          # Shape (1, 2, 3, 4)
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return np.expand_dims(x, self.dim)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.unsqueeze(x, self.dim)
class ToTensor(BaseTransform):
    """Convert a NumPy array or a PyTorch tensor into a PyTorch tensor of a given dtype.

    NumPy arrays are converted with ``torch.as_tensor``; existing tensors are
    cast with ``Tensor.to``.

    Args:
        dtype (str): Target dtype name. ``BaseTransform`` only accepts
            'float32', 'float64', 'complex64' or 'complex128'; any other value
            raises an AssertionError at construction.

    Returns:
        torch.Tensor: The converted tensor with the requested dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.ones((1, 2, 3))
        >>> x_tensor = transform(x_numpy)  # torch.float32 tensor
        >>> x_existing = torch.ones((1, 2, 3), dtype=torch.float64)
        >>> x_converted = transform(x_existing)  # torch.float32 tensor
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        # The return value is a tensor (the annotation previously said ndarray).
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self.torch_dtype)