Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/transforms.py: 62%

204 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-26 05:19 +0000

1# MIT License 

2 

3# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from abc import ABC, abstractmethod 

25from typing import Tuple, Union, Optional, Dict 

26from types import ModuleType 

27 

28# External imports 

29import torch 

30import numpy as np 

31from PIL import Image 

32 

33# Internal imports 

34import torchcvnn.transforms.functional as F 

35 

36 

37class BaseTransform(ABC): 

38 """Abstract base class for transforms that can handle both numpy arrays and PyTorch tensors. 

39 This class serves as a template for implementing transforms that can be applied to both numpy arrays 

40 and PyTorch tensors while maintaining consistent behavior. 

41 Inputs must be in CHW (Channel, Height, Width) format. If inputs have only 2 dimensions (Height, Width), 

42 they will be converted to (1, Height, Width). 

43 

44 Args: 

45 dtype (str, optional): Data type to convert inputs to. Must be one of: 

46 'float32', 'float64', 'complex64', 'complex128'. If None, no type conversion is performed. 

47 Default: None. 

48 

49 Raises: 

50 AssertionError: If dtype is not a string or not one of the allowed types. 

51 ValueError: If input is neither a numpy array nor a PyTorch tensor. 

52 

53 Methods: 

54 __call__(x): Apply the transform to the input array/tensor. 

55 __call_numpy__(x): Abstract method to implement numpy array transform. 

56 __call_torch__(x): Abstract method to implement PyTorch tensor transform. 

57 

58 Example: 

59 >>> class MyTransform(BaseTransform): 

60 >>> def __call_numpy__(self, x): 

61 >>> # Implement numpy transform 

62 >>> pass 

63 >>> def __call_torch__(self, x): 

64 >>> # Implement torch transform 

65 >>> pass 

66 >>> transform = MyTransform(dtype='float32') 

67 >>> output = transform(input_data) # Works with both numpy arrays and torch tensors 

68 """ 

69 

70 def __init__(self, dtype: str = None) -> None: 

71 if dtype is not None: 

72 assert isinstance(dtype, str), "dtype should be a string" 

73 assert dtype in [ 

74 "float32", 

75 "float64", 

76 "complex64", 

77 "complex128", 

78 ], "dtype should be one of float32, float64, complex64, complex128" 

79 self.np_dtype = getattr(np, dtype) 

80 self.torch_dtype = getattr(torch, dtype) 

81 

82 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

83 """Apply transform to input.""" 

84 x = F.check_input(x) 

85 if isinstance(x, np.ndarray): 

86 return self.__call_numpy__(x) 

87 elif isinstance(x, torch.Tensor): 

88 return self.__call_torch__(x) 

89 

90 @abstractmethod 

91 def __call_numpy__(self, x: np.ndarray) -> np.ndarray: 

92 """Apply transform to numpy array.""" 

93 raise NotImplementedError 

94 

95 @abstractmethod 

96 def __call_torch__(self, x: torch.Tensor) -> torch.Tensor: 

97 """Apply transform to torch tensor.""" 

98 raise NotImplementedError 

99 

100 

class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex values, optionally keeping the phase.

    The actual computation is delegated to ``F.log_normalize_amplitude``,
    which receives the input, the backend module (``np`` or ``torch``) and the
    three settings stored on this object. As documented for this transform,
    the processing steps are:
        1. Extract amplitude and phase from the complex input
        2. Clip the amplitude between min_value and max_value
        3. Apply a log10 transform and normalize to [0, 1]
        4. Optionally recombine with the original phase

    Args:
        min_value (int | float, optional): Lower clipping bound for the
            amplitude. Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound for the
            amplitude. Defaults to 40.
        keep_phase (bool, optional): If True the output is complex (transformed
            amplitude with the original phase); if False only the transformed
            amplitude is returned. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Whatever ``F.log_normalize_amplitude``
        returns for the given backend.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # Transforms amplitudes to log scale [0,1]

    Note:
        This class does not call ``BaseTransform.__init__``; it stores only the
        three settings above.
    """

    def __init__(
        self, min_value: float = 0.02, max_value: float = 40, keep_phase: bool = True
    ) -> None:
        self.keep_phase = keep_phase
        self.min_value = min_value
        self.max_value = max_value

    def _normalize(self, x, backend):
        """Run the shared functional implementation with the given backend module."""
        lower, upper = self.min_value, self.max_value
        return F.log_normalize_amplitude(x, backend, self.keep_phase, lower, upper)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        return self._normalize(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        return self._normalize(x, torch)

148 

149 

class Amplitude(BaseTransform):
    """Replace every element of the input by its magnitude.

    The absolute value is computed with the backend's own ``abs`` and the
    result is cast to the dtype chosen at construction time.

    Args:
        dtype (str): Output data type name; validated by ``BaseTransform``
            ('float32', 'float64', ...).

    Returns:
        np.ndarray | torch.Tensor: Magnitudes with the same shape as the input,
        cast to ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        # dtype is mandatory here; the base class resolves it for both backends.
        super().__init__(dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path: ``|x|`` cast to ``self.torch_dtype``."""
        magnitude = x.abs()
        return magnitude.to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path: ``|x|`` cast to ``self.np_dtype``."""
        magnitude = np.absolute(x)
        return magnitude.astype(self.np_dtype)

172 

173 

class RealImaginary(BaseTransform):
    """Split a complex-valued input into real and imaginary channels.

    The real and imaginary parts are stacked on a new leading axis and the two
    leading axes are then merged. For a (C, H, W) input the result has shape
    (2*C, H, W): the first C channels hold the real parts, the last C channels
    the imaginary parts.

    Returns:
        np.ndarray | torch.Tensor: Real-valued array/tensor of shape
        (2*C, H, W), where C is the original number of channels.
    """

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        # (C, H, W) -> (2, C, H, W)
        parts = torch.stack((x.real, x.imag), dim=0)
        # (2, C, H, W) -> (2*C, H, W)
        return torch.flatten(parts, start_dim=0, end_dim=1)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        # (C, H, W) -> (2, C, H, W)
        parts = np.stack((x.real, x.imag), axis=0)
        # Merge the two leading axes: (2, C, H, W) -> (2*C, H, W)
        trailing_shape = parts.shape[2:]
        return parts.reshape((-1, *trailing_shape))

195 

196 

class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Every element of the input is multiplied by ``exp(j*phi)`` where ``phi`` is
    an independent, real-valued sample drawn uniformly from [0, 2π), or from
    [-π, π) when centering is enabled. Because ``phi`` is real, the factor has
    unit modulus and the amplitude of the input is preserved.

    Args:
        dtype (str): Data type name for the output. Must be accepted by
            ``BaseTransform`` and should be a complex dtype.
        centering (bool, optional): If True, π is subtracted from the sampled
            phases so that they are centred around 0. Default is False.

    Returns:
        torch.Tensor or numpy.ndarray: Phase-shifted data with the same shape
        as the input, cast to ``dtype``.

    Examples:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # Applies random phase shifts

    Notes:
        - Input data is expected to be complex-valued.
        - The NumPy path samples with ``np.random.rand`` and the PyTorch path
          with ``torch.rand``; both use the respective global random state.
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        # The phase must be REAL. ``torch.rand_like`` on a complex tensor
        # samples a complex tensor (random real and imaginary parts), and
        # exp(1j * (a + ib)) = exp(-b) * exp(ia) would rescale the amplitude.
        # Sample explicitly with the matching real dtype instead.
        if x.is_complex():
            real_dtype = x.real.dtype
        elif x.is_floating_point():
            real_dtype = x.dtype
        else:
            real_dtype = torch.get_default_dtype()

        phase = torch.rand(x.shape, dtype=real_dtype, device=x.device) * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)

240 

241 

242class ToReal: 

243 """Extracts the real part of a complex-valued input tensor. 

244 

245 The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers 

246 and returns only their real parts. If the input is already real-valued, it remains unchanged. 

247 

248 Returns: 

249 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only 

250 the real components of each element. 

251 

252 Example: 

253 >>> to_real = ToReal() 

254 >>> output = to_real(complex_tensor) 

255 """ 

256 

257 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

258 return x.real 

259 

260 

261class ToImaginary: 

262 """Extracts the imaginary part of a complex-valued input tensor. 

263 

264 The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers 

265 and returns only their imaginary parts. If the input is already real-valued, it remains unchanged. 

266 

267 Returns: 

268 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only 

269 the imaginary components of each element. 

270 

271 Example: 

272 >>> to_imaginary = ToImaginary() 

273 >>> output = to_imaginary(complex_tensor) 

274 """ 

275 

276 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

277 return x.imag 

278 

279 

class FFT2(BaseTransform):
    """Apply a 2D Fast Fourier Transform (FFT) to the input.

    The work is delegated to ``F.applyfft2_np`` (NumPy) or
    ``F.applyfft2_torch`` (PyTorch), called with the configured axes. As
    documented for this transform, those helpers apply FFT2 and shift the
    zero-frequency components to the centre.

    Args:
        axis (Tuple[int, ...], optional): Axes over which to compute the FFT.
            Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The 2D Fourier transformed input with
        zero-frequency components centred; same shape as the input.

    Notes:
        - This class does not call ``BaseTransform.__init__``; only ``axis``
          is stored.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        axes = self.axis
        return F.applyfft2_np(x, axis=axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        dims = self.axis
        return F.applyfft2_torch(x, dim=dims)

306 

307 

class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform (IFFT) to the input.

    The work is delegated to ``F.applyifft2_np`` (NumPy) or
    ``F.applyifft2_torch`` (PyTorch), called with the configured axes. As
    documented for this transform, those helpers apply an inverse FFT shift
    before IFFT2.

    Args:
        axis (Tuple[int, ...], optional): Axes over which to compute the
            inverse FFT. Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The inverse Fourier transformed input;
        same shape as the input.

    Notes:
        - This class does not call ``BaseTransform.__init__``; only ``axis``
          is stored.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        axes = self.axis
        return F.applyifft2_np(x, axis=axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        dims = self.axis
        return F.applyifft2_torch(x, dim=dims)

334 

335 

class PadIfNeeded(BaseTransform):
    """Pad an image whose height or width is below a required minimum.

    Both backends delegate to ``F.padifneeded`` with the settings stored on
    this object. As documented for this transform, padding is added
    symmetrically on both sides where possible; when the amount to add is odd,
    the right and bottom sides receive one more element than the left and top.

    Args:
        min_height (int): Minimum height requirement for the image.
        min_width (int): Minimum width requirement for the image.
        border_mode (str): Type of padding to apply ('constant', 'reflect',
            etc.). Default is 'constant'.
        pad_value (float): Value for constant padding (if applicable).
            Default is 0.

    Returns:
        np.ndarray | torch.Tensor: Padded image with dimensions at least
        min_height x min_width, or the original image if no padding is needed.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # Pads if image is smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0,
    ) -> None:
        self.min_height, self.min_width = min_height, min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        """Call the backend-agnostic functional helper with the stored settings."""
        return F.padifneeded(
            x, self.min_height, self.min_width, self.border_mode, self.pad_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        return self._pad(x)

381 

382 

class CenterCrop(BaseTransform):
    """Center crop an input array/tensor to the requested size.

    Both backends delegate to ``F.center_crop`` with the stored target height
    and width.

    Args:
        height (int): Target height of the cropped output.
        width (int): Target width of the cropped output.

    Returns:
        np.ndarray | torch.Tensor: Center cropped array/tensor with shape
        (C, height, width), where C is the number of channels.

    Examples:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # Center crops to 224x224

    Notes:
        - Per the original documentation, an input smaller than the crop size
          is returned unchanged and the crop is applied identically to all
          channels; both properties are implemented in ``F.center_crop``.
    """

    def __init__(self, height: int, width: int) -> None:
        self.height, self.width = height, width

    def _crop(self, x):
        """Call the backend-agnostic functional helper with the stored size."""
        return F.center_crop(x, self.height, self.width)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        return self._crop(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        return self._crop(x)

416 

417 

class FFTResize(BaseTransform):
    """Resize an input image in the spectral domain with Fourier transforms.

    The transform applies a 2D FFT to the CHW input along the given axes, pads
    and then center crops the spectrum to the target size, and applies the
    inverse FFT to return to the spatial domain. Optionally, the output is
    multiplied by ``target_size / original_size`` (number of pixels after and
    before resizing). The result is always cast to the configured complex
    dtype, whether or not scaling is enabled.

    Args:
        size (Tuple[int, int]): Target dimensions (height, width).
        axis (Tuple[int, ...], optional): Axes over which to apply the FFT.
            Default is (-2, -1), i.e. the Height and Width axes of a CHW input.
        scale (bool, optional): If True, multiply the output by
            ``(height * width) / (H * W)`` where H, W are the input's second
            and third dimensions. Default is False.
        dtype (str, optional): Name of the output data type, 'complex64' or
            'complex128'. None, or any value whose string form does not
            contain 'complex', falls back to 'complex64'.

    Returns:
        numpy.ndarray or torch.Tensor: Resized image as a complex-valued
        array/tensor of shape (C, height, width).

    Raises:
        AssertionError: If ``size`` or ``axis`` is not a tuple, or if ``dtype``
            is rejected by ``BaseTransform``.

    Examples:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # Resize to 128x128 using FFT

    Notes:
        - Input must be an array/tensor of shape Channel x Height x Width; the
          scale factor reads ``x.shape[1]`` and ``x.shape[2]``.
        - The output is complex-valued due to the nature of FFT operations. If
          you are working with real-valued data, it is recommended to call
          ToReal after applying this transform.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[str] = "complex64",
    ) -> None:
        # The output of an inverse FFT is complex; silently fall back to
        # complex64 for missing or non-complex dtype requests.
        if dtype is None or "complex" not in str(dtype):
            dtype = "complex64"

        super().__init__(dtype)
        assert isinstance(size, tuple), "size must be a tuple"
        assert isinstance(axis, tuple), "axis must be a tuple"
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        # Pixel counts must be read before x is replaced by its spectrum.
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_np(x, axis=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path: previously the scaled result skipped the cast.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        # Pixel counts must be read before x is replaced by its spectrum.
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_torch(x, dim=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast on every path: previously the scaled result skipped the cast.
        return x.to(self.torch_dtype)

501 

502 

class SpatialResize:
    """Resize a complex array/tensor to a given size in image space.

    The real and imaginary parts are resized independently, one 2D plane at a
    time, by converting each plane to a PIL image and calling
    ``Image.resize`` with Pillow's default resampling filter (described by the
    original authors as bicubic), then recombined into a complex result.

    Arguments:
        size: The target size of the resized tensor, indexed as
            ``(height, width)``.

    Notes:
        - 2D inputs are treated as a single plane; any other input is iterated
          over its first dimension, which is assumed to be the channel axis.
        - PyTorch tensors are detached and moved to the CPU for the resize and
          the result is returned on the input's original device.
    """

    def __init__(self, size):
        self.size = size

    def _resize_plane(self, plane: np.ndarray) -> np.ndarray:
        """Resize one real-valued 2D plane through PIL and return a numpy array."""
        image = Image.fromarray(plane)
        # PIL expects (width, height) whereas self.size is (height, width).
        image = image.resize((self.size[1], self.size[0]))
        return np.array(image)

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """Resize ``array`` and return the same container type that was given."""
        is_torch = isinstance(array, torch.Tensor)
        device = None
        if is_torch:
            device = array.device
            # ``Tensor.numpy()`` refuses tensors that require grad or live on
            # another device, so detach and move to the CPU first.
            array = array.detach().cpu().numpy()

        real_part = array.real
        imaginary_part = array.imag

        if array.ndim == 2:
            # Single plane.
            resized_real = self._resize_plane(real_part)
            resized_imaginary = self._resize_plane(imaginary_part)
        else:
            # Resize channel by channel; the first dimension is assumed to be
            # the channel axis.
            resized_real = np.stack([self._resize_plane(p) for p in real_part])
            resized_imaginary = np.stack(
                [self._resize_plane(p) for p in imaginary_part]
            )

        resized_array = resized_real + 1j * resized_imaginary

        # Convert the result back to a torch tensor if necessary.
        if is_torch:
            resized_array = torch.as_tensor(resized_array).to(device)

        return resized_array

561 

562 

class PolSAR(BaseTransform):
    """Convert between channel layouts of Polarimetric SAR (PolSAR) data.

    The transform selects or combines channels of a CHW input so that the
    output has ``out_channel`` channels. Inputs are first passed through
    ``F.polsar_dict_to_array``; per the original documentation, a dictionary
    such as {'HH': data1, 'VV': data2} is stacked along axis 0 into a CHW
    array by that helper.

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3, or 4).

    Supported conversions:
        - 1 channel -> 1 channel: identity
        - 2 channels -> 1 channel (first channel only) or 2 channels (identity)
        - 4 channels -> 1, 2, 3, or 4 channels where:
            - 1 channel: first channel only
            - 2 channels: channels 0 and 3, i.e. [HH, VV]
            - 3 channels: [HH, (HV+VH)/2, VV]
            - 4 channels: all channels [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested channel conversion is not supported.

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For 4-channel input [HH, HV, VH, VV]
        >>> output = transform(input_data)  # Returns [HH, (HV+VH)/2, VV]

    Note:
        - Input data should have format Channels x Height x Width (CHW).
        - With ``out_channel=1`` the first channel is returned with its
          channel axis kept (slice ``x[0:1]``).
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        """Return ``x`` for a 1 -> 1 conversion, otherwise None."""
        if out_channels == 1:
            return x
        return None

    def _handle_two_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        """Handle a 2-channel input; None marks an unsupported target."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return x
        return None

    def _handle_four_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Handle a 4-channel input; ``backend`` (np or torch) provides ``stack``."""
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return backend.stack((x[0], x[3]))
        if out_channels == 3:
            # Average the two cross-polarised channels.
            return backend.stack((x[0], 0.5 * (x[1] + x[2]), x[3]))
        if out_channels == 4:
            return x
        return None

    def _convert_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        """Dispatch on the number of input channels and validate the result."""
        in_channels = x.shape[0]
        if in_channels == 1:
            converted = self._handle_single_channel(x, out_channels)
        elif in_channels == 2:
            converted = self._handle_two_channels(x, out_channels)
        elif in_channels == 4:
            converted = self._handle_four_channels(x, out_channels, backend)
        else:
            converted = None

        # Every handler signals "unsupported" by returning None.
        if converted is None:
            raise ValueError(
                f"Invalid conversion: {x.shape[0]} -> {out_channels} channels"
            )
        return converted

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(
        self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]
    ) -> np.ndarray | torch.Tensor:
        """Normalise dictionary inputs, then run the regular dispatch."""
        stacked = F.polsar_dict_to_array(x)
        return super().__call__(stacked)

648 

649 

class Unsqueeze(BaseTransform):
    """Insert a singleton dimension into the input array/tensor.

    A new axis of length 1 is inserted at position ``dim``, using
    ``np.expand_dims`` for numpy arrays and ``torch.unsqueeze`` for tensors.
    When the transform is invoked through ``__call__``, the input is first
    normalised by ``BaseTransform.__call__``.

    Args:
        dim (int): Position where the new axis should be inserted.

    Returns:
        np.ndarray | torch.Tensor: Input with a new singleton dimension at
        position ``dim``.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(3,4)  # Shape (3,4)
        >>> y = transform(x)  # Shape (1,3,4)
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        """NumPy code path."""
        position = self.dim
        return np.expand_dims(x, axis=position)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path."""
        position = self.dim
        return torch.unsqueeze(x, position)

677 

678 

class ToTensor(BaseTransform):
    """Convert a numpy array or torch tensor to a torch tensor of a given dtype.

    Numpy arrays are converted with ``torch.as_tensor`` and existing tensors
    with ``Tensor.to``; in both cases the result has the dtype selected at
    construction time.

    Args:
        dtype (str): Target data type name for the output tensor. It is
            validated by ``BaseTransform`` and must therefore be one of
            'float32', 'float64', 'complex64', 'complex128'.

    Returns:
        torch.Tensor: The converted tensor with the specified dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.ones((1, 4, 4))
        >>> x_tensor = transform(x_numpy)  # torch.float32 tensor
        >>> x_existing = torch.ones((1, 4, 4), dtype=torch.int32)
        >>> x_converted = transform(x_existing)  # torch.float32 tensor
    """

    def __init__(self, dtype: str) -> None:
        # dtype is mandatory here; the base class resolves ``torch_dtype``.
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        """NumPy code path; note that the result is a torch tensor."""
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch code path: cast the tensor to the target dtype."""
        return x.to(self.torch_dtype)