Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/transforms.py: 62%

204 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from abc import ABC, abstractmethod 

25from typing import Tuple, Union, Optional, Dict 

26from types import NoneType, ModuleType 

27 

28# External imports 

29import torch 

30import numpy as np 

31from PIL import Image 

32 

33# Internal imports 

34import torchcvnn.transforms.functional as F 

35 

36 

37class BaseTransform(ABC): 

38 """Abstract base class for transforms that can handle both numpy arrays and PyTorch tensors. 

39 This class serves as a template for implementing transforms that can be applied to both numpy arrays 

40 and PyTorch tensors while maintaining consistent behavior.  

41 Inputs must be in CHW (Channel, Height, Width) format. If inputs have only 2 dimensions (Height, Width), 

42 they will be converted to (1, Height, Width). 

43  

44 Args: 

45 dtype (str, optional): Data type to convert inputs to. Must be one of: 

46 'float32', 'float64', 'complex64', 'complex128'. If None, no type conversion is performed. 

47 Default: None. 

48  

49 Raises: 

50 AssertionError: If dtype is not a string or not one of the allowed types. 

51 ValueError: If input is neither a numpy array nor a PyTorch tensor. 

52  

53 Methods: 

54 __call__(x): Apply the transform to the input array/tensor. 

55 __call_numpy__(x): Abstract method to implement numpy array transform. 

56 __call_torch__(x): Abstract method to implement PyTorch tensor transform. 

57  

58 Example: 

59 >>> class MyTransform(BaseTransform): 

60 >>> def __call_numpy__(self, x): 

61 >>> # Implement numpy transform 

62 >>> pass 

63 >>> def __call_torch__(self, x): 

64 >>> # Implement torch transform 

65 >>> pass 

66 >>> transform = MyTransform(dtype='float32') 

67 >>> output = transform(input_data) # Works with both numpy arrays and torch tensors 

68 """ 

69 def __init__(self, dtype: str | NoneType = None) -> None: 

70 if dtype is not None: 

71 assert isinstance(dtype, str), "dtype should be a string" 

72 assert dtype in ["float32", "float64", "complex64", "complex128"], "dtype should be one of float32, float64, complex64, complex128" 

73 self.np_dtype = getattr(np, dtype) 

74 self.torch_dtype = getattr(torch, dtype) 

75 

76 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

77 """Apply transform to input.""" 

78 x = F.check_input(x) 

79 if isinstance(x, np.ndarray): 

80 return self.__call_numpy__(x) 

81 elif isinstance(x, torch.Tensor): 

82 return self.__call_torch__(x) 

83 

84 @abstractmethod 

85 def __call_numpy__(self, x: np.ndarray) -> np.ndarray: 

86 """Apply transform to numpy array.""" 

87 raise NotImplementedError 

88 

89 @abstractmethod 

90 def __call_torch__(self, x: torch.Tensor) -> torch.Tensor: 

91 """Apply transform to torch tensor.""" 

92 raise NotImplementedError 

93 

94 

class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex values, optionally keeping the phase.

    The work is delegated to ``F.log_normalize_amplitude``. As documented by
    the original authors, the steps are:

        1. Extract amplitude and phase from the complex input
        2. Clip the amplitude to [min_value, max_value]
        3. Apply log10 and normalise to [0, 1]
        4. Optionally recombine with the original phase

    NOTE(review): these steps describe the helper, which is not visible from
    this module - confirm against ``functional.log_normalize_amplitude``.

    Args:
        min_value (int | float, optional): Lower clipping bound for the
            amplitude. Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound for the
            amplitude. Defaults to 40.
        keep_phase (bool, optional): If True the output is documented to be
            complex (transformed amplitude, original phase); if False only the
            transformed amplitude is returned. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Result of ``F.log_normalize_amplitude`` for
        the matching backend.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # Transforms amplitudes to log scale [0,1]
    """

    def __init__(
        self,
        min_value: float = 0.02,
        max_value: float = 40,
        keep_phase: bool = True,
    ) -> None:
        self.min_value, self.max_value = min_value, max_value
        self.keep_phase = keep_phase

    def _normalize(self, x, backend):
        """Call the shared functional helper with the given backend module."""
        return F.log_normalize_amplitude(
            x, backend, self.keep_phase, self.min_value, self.max_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Apply the transform using the numpy backend."""
        return self._normalize(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transform using the torch backend."""
        return self._normalize(x, torch)

135 

136 

class Amplitude(BaseTransform):
    """Replace each (complex) value by its magnitude.

    The absolute value of the input is computed and the result is cast to the
    dtype given at construction time.

    Args:
        dtype (str): Output data type; validated by ``BaseTransform``
            ('float32', 'float64', 'complex64' or 'complex128').

    Returns:
        np.ndarray | torch.Tensor: Magnitudes with the same shape as the
        input, cast to ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Magnitude of a torch tensor, cast to ``self.torch_dtype``."""
        magnitude = torch.abs(x)
        return magnitude.to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Magnitude of a numpy array, cast to ``self.np_dtype``."""
        magnitude = np.abs(x)
        return magnitude.astype(self.np_dtype)

158 

159 

class RealImaginary(BaseTransform):
    """Split a complex-valued input into real and imaginary channels.

    The real and imaginary parts are stacked on a new leading axis, and that
    axis is then merged with the original first axis. For a (C, H, W) input
    the result is (2*C, H, W): the C real channels followed by the C
    imaginary channels.

    Returns:
        np.ndarray | torch.Tensor: Real-valued data with twice as many entries
        along the first axis as the input.
    """

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Torch implementation: (C, H, W) -> (2, C, H, W) -> (2*C, H, W)."""
        parts = torch.stack([x.real, x.imag], dim=0)
        return parts.flatten(0, 1)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Numpy implementation: (C, H, W) -> (2, C, H, W) -> (2*C, H, W)."""
        parts = np.stack([x.real, x.imag], axis=0)
        trailing_shape = parts.shape[2:]
        return parts.reshape(-1, *trailing_shape)

180 

181 

class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Every element of the input is multiplied by exp(j*phi), where phi is drawn
    independently per element from a uniform distribution over [0, 2*pi), or
    over [-pi, pi) when centering is enabled. The phase is always real-valued,
    so the multiplication is a pure rotation that leaves magnitudes unchanged.

    Args:
        dtype (str): Data type of the output; validated by ``BaseTransform``.
            Should be one of the complex dtypes ('complex64', 'complex128').
        centering (bool, optional): If True, pi is subtracted from the sampled
            phases so that they are centred on 0. Default is False.

    Returns:
        torch.Tensor or numpy.ndarray: Phase-shifted data with the same shape
        as the input, cast to ``dtype``.

    Examples:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # Applies random phase shifts
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply a random per-element phase rotation to a torch tensor."""
        # The phase must be sampled in a *real* dtype. ``torch.rand_like(x)``
        # inherits x's complex dtype, so the sampled "phase" would itself be
        # complex and exp(1j * phase) would rescale magnitudes instead of
        # only rotating.
        if x.is_complex():
            phase_dtype = x.real.dtype
        elif x.is_floating_point():
            phase_dtype = x.dtype
        else:
            phase_dtype = torch.get_default_dtype()
        phase = torch.rand(x.shape, dtype=phase_dtype, device=x.device) * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Apply a random per-element phase rotation to a numpy array."""
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)

224 

225 

226class ToReal: 

227 """Extracts the real part of a complex-valued input tensor. 

228 

229 The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers  

230 and returns only their real parts. If the input is already real-valued, it remains unchanged. 

231 

232 Returns: 

233 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only  

234 the real components of each element. 

235  

236 Example: 

237 >>> to_real = ToReal() 

238 >>> output = to_real(complex_tensor) 

239 """ 

240 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

241 return x.real 

242 

243 

244class ToImaginary: 

245 """Extracts the imaginary part of a complex-valued input tensor. 

246 

247 The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers  

248 and returns only their imaginary parts. If the input is already real-valued, it remains unchanged. 

249 

250 Returns: 

251 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only  

252 the imaginary components of each element. 

253  

254 Example: 

255 >>> to_imaginary = ToImaginary() 

256 >>> output = to_imaginary(complex_tensor) 

257 """ 

258 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

259 return x.imag 

260 

261 

class FFT2(BaseTransform):
    """Apply a 2D Fast Fourier Transform to the input.

    The computation is delegated to ``F.applyfft2_np`` / ``F.applyfft2_torch``.
    According to the original documentation these compute the 2D FFT over the
    given axes and shift the zero-frequency component to the centre.
    NOTE(review): the shift happens inside the helpers, which are not visible
    from this module - confirm there.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The Fourier-transformed input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Numpy path; the axes are forwarded as ``axis``."""
        fft_axes = self.axis
        return F.applyfft2_np(x, axis=fft_axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Torch path; the axes are forwarded as ``dim``."""
        fft_dims = self.axis
        return F.applyfft2_torch(x, dim=fft_dims)

287 

288 

class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform to the input.

    The computation is delegated to ``F.applyifft2_np`` /
    ``F.applyifft2_torch``. According to the original documentation these
    undo the frequency shift before computing the inverse 2D FFT over the
    given axes.
    NOTE(review): the inverse shift happens inside the helpers, which are not
    visible from this module - confirm there.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the inverse FFT is
            computed. Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The inverse-Fourier-transformed input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Numpy path; the axes are forwarded as ``axis``."""
        ifft_axes = self.axis
        return F.applyifft2_np(x, axis=ifft_axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Torch path; the axes are forwarded as ``dim``."""
        ifft_dims = self.axis
        return F.applyifft2_torch(x, dim=ifft_dims)

314 

315 

class PadIfNeeded(BaseTransform):
    """Pad an image whose height or width is below a required minimum.

    Padding is delegated to ``F.padifneeded`` for both backends. As documented
    by the original authors, padding is split symmetrically between the two
    sides where possible; when the amount to add is odd, the right and bottom
    sides receive one more element than the left and top sides.
    NOTE(review): the padding layout is implemented in the helper, which is
    not visible from this module - confirm there.

    Args:
        min_height (int): Minimum height required for the image.
        min_width (int): Minimum width required for the image.
        border_mode (str): Padding mode ('constant', 'reflect', etc.).
            Default is 'constant'.
        pad_value (float): Fill value for constant padding. Default is 0.

    Returns:
        np.ndarray | torch.Tensor: The padded image, or the result of
        ``F.padifneeded`` on an image that is already large enough.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # Pads if image is smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0
    ) -> None:
        self.min_height, self.min_width = min_height, min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        """Forward the stored settings to the shared functional helper."""
        return F.padifneeded(
            x, self.min_height, self.min_width, self.border_mode, self.pad_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Pad a numpy array."""
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Pad a torch tensor."""
        return self._pad(x)

356 

357 

class CenterCrop(BaseTransform):
    """Crop a centred region of the requested size from the input.

    Both backends delegate to ``F.center_crop`` with the stored target height
    and width.

    Args:
        height (int): Target height of the cropped output.
        width (int): Target width of the cropped output.

    Returns:
        np.ndarray | torch.Tensor: Centre-cropped array/tensor; documented as
        having shape (C, height, width).

    Examples:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # Center crops to 224x224

    Notes:
        - The original documentation states that an input smaller than the
          crop size is returned unchanged (behaviour of the helper; not
          visible from this module).
        - The same helper, ``functional.center_crop``, serves numpy and torch.
    """

    def __init__(self, height: int, width: int) -> None:
        self.height, self.width = height, width

    def _crop(self, x):
        """Forward the stored size to the shared functional helper."""
        return F.center_crop(x, self.height, self.width)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Centre-crop a numpy array."""
        return self._crop(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Centre-crop a torch tensor."""
        return self._crop(x)

390 

391 

class FFTResize(BaseTransform):
    """Resize an image in the spectral domain using Fourier transforms.

    The input is taken to the frequency domain with a 2D FFT, padded and/or
    centre-cropped there to the target size, and brought back with an inverse
    2D FFT. Optionally the result is multiplied by
    ``target_size / original_size`` (ratio of pixel counts). The output is
    always cast to the configured complex dtype, whether or not scaling is
    applied.

    Args:
        size (Tuple[int, int]): Target dimensions (height, width).
        axis (Tuple[int, ...], optional): Axes over which the FFT is applied.
            Default is (-2, -1), i.e. height and width of a CHW input.
        scale (bool, optional): If True, multiply the output by the ratio of
            target to original pixel count. Default is False.
        dtype (str, optional): Output data type as a string, 'complex64' or
            'complex128'. None, or any value whose string form does not
            contain 'complex', falls back to 'complex64'.
            Default is 'complex64'.

    Returns:
        numpy.ndarray or torch.Tensor: Resized complex-valued array/tensor of
        shape (C, height, width).

    Raises:
        AssertionError: If ``size`` or ``axis`` is not a tuple, or if
            ``dtype`` is rejected by ``BaseTransform``.

    Examples:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # Resize to 128x128 using FFT

    Notes:
        - Input must have shape Channel x Height x Width: the original pixel
          count is read from axes 1 and 2.
        - The output is complex-valued. For real-valued data, call ToReal
          after applying this transform.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[str] = "complex64"
    ) -> None:
        # Anything that is not recognisably complex falls back to complex64,
        # since the inverse FFT output is complex.
        if dtype is None or "complex" not in str(dtype):
            dtype = "complex64"

        super().__init__(dtype)
        assert isinstance(size, tuple), "size must be a tuple"
        assert isinstance(axis, tuple), "axis must be a tuple"
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Resize a numpy array of shape (C, H, W) in the frequency domain."""
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_np(x, axis=self.axis)
        # Pad first (upsampling), then crop (downsampling); each is expected
        # to be a no-op along an axis that already fits.
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path; previously the scaled result skipped the cast.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Resize a torch tensor of shape (C, H, W) in the frequency domain."""
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_torch(x, dim=self.axis)
        # Pad first (upsampling), then crop (downsampling); each is expected
        # to be a no-op along an axis that already fits.
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path; previously the scaled result skipped the cast.
        return x.to(self.torch_dtype)

474 

475 

class SpatialResize:
    """
    Resize a complex array/tensor in image space.

    The real and imaginary parts are resized independently with PIL
    (``Image.resize`` with its default resampling filter, documented by the
    original authors as bicubic) and then recombined into a complex result.

    Arguments:
        size: The target size of the resized tensor, as (height, width).
    """

    def __init__(self, size):
        self.size = size

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """Resize a 2-D (H, W) or channel-first (C, H, W) array/tensor.

        Torch tensors are converted to numpy for the resize and converted back
        to a tensor at the end.
        """
        came_from_torch = isinstance(array, torch.Tensor)
        if came_from_torch:
            array = array.numpy()

        def resize_plane(plane):
            # numpy -> PIL image -> resize -> numpy.
            # PIL expects (width, height), hence the swapped indices.
            picture = Image.fromarray(plane)
            picture = picture.resize((self.size[1], self.size[0]))
            return np.array(picture)

        real_part, imaginary_part = array.real, array.imag

        if array.ndim == 2:
            # A single 2-D plane: resize both parts directly.
            resized_real = resize_plane(real_part)
            resized_imaginary = resize_plane(imaginary_part)
        else:
            # Channel-first input: resize every channel separately and
            # restack along the channel axis.
            resized_real = np.stack([resize_plane(ch) for ch in real_part])
            resized_imaginary = np.stack(
                [resize_plane(ch) for ch in imaginary_part]
            )

        resized_array = resized_real + 1j * resized_imaginary

        # Hand back a torch tensor if that is what we were given.
        return torch.as_tensor(resized_array) if came_from_torch else resized_array

535 

536 

class PolSAR(BaseTransform):
    """Convert between channel layouts of Polarimetric SAR (PolSAR) data.

    Supports 1, 2 or 4 input channels and reduces them to the requested number
    of output channels, for both numpy arrays and PyTorch tensors. The input is
    first passed through ``F.polsar_dict_to_array``; per the original
    documentation a dictionary such as {'HH': data1, 'VV': data2} is stacked
    along axis 0 into a CHW array (helper not visible from this module).

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3, or 4).

    Supported conversions:
        - 1 channel -> 1 channel: identity
        - 2 channels -> 1 channel (first channel) or 2 channels (identity)
        - 4 channels -> 1, 2, 3, or 4 channels where:
            - 1 channel: first channel only
            - 2 channels: channels 0 and 3, documented as [HH, VV]
            - 3 channels: channel 0, mean of channels 1 and 2, channel 3,
              documented as [HH, (HV+VH)/2, VV]
            - 4 channels: all channels, documented as [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested channel conversion is not supported.

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For 4-channel input [HH, HV, VH, VV]
        >>> output = transform(input_data)  # Returns [HH, (HV+VH)/2, VV]

    Note:
        - Input data should have format Channels x Height x Width (CHW).
        - With out_channel=1 the first channel is always returned.
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        """1-channel input: only the identity conversion is supported."""
        if out_channels == 1:
            return x
        return None

    def _handle_two_channels(self, x: np.ndarray | torch.Tensor, out_channels: int) -> np.ndarray | torch.Tensor:
        """2-channel input: keep both channels or only the first one."""
        if out_channels == 1:
            # Slice (not index) so the channel axis is preserved.
            return x[0:1]
        if out_channels == 2:
            return x
        return None

    def _handle_four_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """4-channel input: reduce to 1, 2, 3 or 4 channels.

        ``backend`` is the module (``np`` or ``torch``) providing ``stack``.
        Returns None for an unsupported ``out_channels``.
        """
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return backend.stack((x[0], x[3]))
        if out_channels == 3:
            cross_pol_mean = 0.5 * (x[1] + x[2])
            return backend.stack((x[0], cross_pol_mean, x[3]))
        if out_channels == 4:
            return x
        return None

    def _convert_channels(
        self,
        x: np.ndarray | torch.Tensor,
        out_channels: int,
        backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Dispatch on the number of input channels and validate the result."""
        in_channels = x.shape[0]
        if in_channels == 1:
            converted = self._handle_single_channel(x, out_channels)
        elif in_channels == 2:
            converted = self._handle_two_channels(x, out_channels)
        elif in_channels == 4:
            converted = self._handle_four_channels(x, out_channels, backend)
        else:
            converted = None

        # None marks an unsupported (input, output) channel combination.
        if converted is None:
            raise ValueError(f"Invalid conversion: {x.shape[0]} -> {out_channels} channels")
        return converted

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Convert channels of a numpy array."""
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Convert channels of a torch tensor."""
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]) -> np.ndarray | torch.Tensor:
        """Normalise dict inputs to an array, then run the base-class dispatch."""
        as_array = F.polsar_dict_to_array(x)
        return super().__call__(as_array)

623 

624 

class Unsqueeze(BaseTransform):
    """Insert a singleton dimension into the input array/tensor.

    A new axis of length 1 is inserted at position ``dim``, so the result has
    one more dimension than the array/tensor handed to the backend method.

    Args:
        dim (int): Position at which the new axis is inserted.

    Returns:
        np.ndarray | torch.Tensor: Input with a size-1 axis added at ``dim``.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(3,4)  # Shape (3,4)
        >>> y = transform(x)  # Shape (1,3,4)
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """Numpy path, via ``np.expand_dims``."""
        position = self.dim
        return np.expand_dims(x, axis=position)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Torch path, via ``torch.unsqueeze``."""
        position = self.dim
        return torch.unsqueeze(x, dim=position)

651 

652 

class ToTensor(BaseTransform):
    """Convert a numpy array or torch tensor to a torch tensor of a given dtype.

    Numpy arrays are converted with ``torch.as_tensor``; existing tensors are
    cast with ``Tensor.to``.

    Args:
        dtype (str): Target data type of the output tensor. It is validated by
            ``BaseTransform`` and must therefore be one of 'float32',
            'float64', 'complex64' or 'complex128'.

    Returns:
        torch.Tensor: The converted tensor with the specified dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.array([1, 2, 3])
        >>> x_tensor = transform(x_numpy)  # float32 tensor
        >>> x_existing = torch.tensor([1, 2, 3], dtype=torch.int32)
        >>> x_converted = transform(x_existing)  # float32 tensor
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        """Build a torch tensor of ``self.torch_dtype`` from a numpy array."""
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """Cast an existing tensor to ``self.torch_dtype``."""
        return x.to(self.torch_dtype)