Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/transforms.py: 71%

244 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-21 08:33 +0000

1# MIT License 

2 

3# Copyright (c) 2025 Quentin Gabot, Jeremy Fix, Huy Nguyen 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from abc import ABC, abstractmethod 

25from typing import Tuple, Union, Optional, Dict 

26from types import ModuleType 

27 

28# External imports 

29import torch 

30import numpy as np 

31from PIL import Image 

32 

33# Internal imports 

34import torchcvnn.transforms.functional as F 

35 

36 

37class BaseTransform(ABC): 

38 """Abstract base class for transforms that can handle both numpy arrays and PyTorch tensors. 

39 This class serves as a template for implementing transforms that can be applied to both numpy arrays 

40 and PyTorch tensors while maintaining consistent behavior. 

41 Inputs must be in CHW (Channel, Height, Width) format. If inputs have only 2 dimensions (Height, Width), 

42 they will be converted to (1, Height, Width). 

43 

44 Args: 

45 dtype (str, optional): Data type to convert inputs to. Must be one of: 

46 'float32', 'float64', 'complex64', 'complex128'. If None, no type conversion is performed. 

47 Default: None. 

48 

49 Raises: 

50 AssertionError: If dtype is not a string or not one of the allowed types. 

51 ValueError: If input is neither a numpy array nor a PyTorch tensor. 

52 

53 Methods: 

54 __call__(x): Apply the transform to the input array/tensor. 

55 __call_numpy__(x): Abstract method to implement numpy array transform. 

56 __call_torch__(x): Abstract method to implement PyTorch tensor transform. 

57 

58 Example: 

59 >>> class MyTransform(BaseTransform): 

60 >>> def __call_numpy__(self, x): 

61 >>> # Implement numpy transform 

62 >>> pass 

63 >>> def __call_torch__(self, x): 

64 >>> # Implement torch transform 

65 >>> pass 

66 >>> transform = MyTransform(dtype='float32') 

67 >>> output = transform(input_data) # Works with both numpy arrays and torch tensors 

68 """ 

69 

70 def __init__(self, dtype: str = None) -> None: 

71 if dtype is not None: 

72 assert isinstance(dtype, str), "dtype should be a string" 

73 assert dtype in [ 

74 "float32", 

75 "float64", 

76 "complex64", 

77 "complex128", 

78 ], "dtype should be one of float32, float64, complex64, complex128" 

79 self.np_dtype = getattr(np, dtype) 

80 self.torch_dtype = getattr(torch, dtype) 

81 

82 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

83 """Apply transform to input.""" 

84 x = F.check_input(x) 

85 if isinstance(x, np.ndarray): 

86 return self.__call_numpy__(x) 

87 elif isinstance(x, torch.Tensor): 

88 return self.__call_torch__(x) 

89 

90 @abstractmethod 

91 def __call_numpy__(self, x: np.ndarray) -> np.ndarray: 

92 """Apply transform to numpy array.""" 

93 raise NotImplementedError 

94 

95 @abstractmethod 

96 def __call_torch__(self, x: torch.Tensor) -> torch.Tensor: 

97 """Apply transform to torch tensor.""" 

98 raise NotImplementedError 

99 

100 

class LogAmplitude(BaseTransform):
    """Log-scale the amplitude of complex values, optionally keeping the phase.

    Delegates to ``F.log_normalize_amplitude``. As documented by the original author, the
    amplitude is clipped to [min_value, max_value], log10-transformed and normalized to
    [0, 1], and then optionally recombined with the original phase.

    Steps:
        1. Extract amplitude and phase from the complex input.
        2. Clip the amplitude between min_value and max_value.
        3. Apply a log10 transform and normalize to [0, 1].
        4. Optionally recombine with the original phase.

    Args:
        min_value (int | float, optional): Lower clipping bound for the amplitude.
            Defaults to 0.02.
        max_value (int | float, optional): Upper clipping bound for the amplitude.
            Defaults to 40.
        keep_phase (bool, optional): If True, return complex output (scaled amplitude,
            original phase). If False, return only the scaled amplitude. Defaults to True.

    Returns:
        np.ndarray | torch.Tensor: Transformed data with the same shape as the input.

    Example:
        >>> transform = LogAmplitude(min_value=0.01, max_value=100)
        >>> output = transform(input_tensor)  # amplitudes mapped to log scale [0, 1]

    Note:
        NumPy arrays and PyTorch tensors share one implementation; only the backend
        module handed to the functional helper differs.
    """

    def __init__(
        self, min_value: float = 0.02, max_value: float = 40, keep_phase: bool = True
    ) -> None:
        self.min_value, self.max_value = min_value, max_value
        self.keep_phase = keep_phase

    def _apply(self, x, backend):
        # Single call site for both backends; arguments are positional as in the helper.
        return F.log_normalize_amplitude(
            x, backend, self.keep_phase, self.min_value, self.max_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._apply(x, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._apply(x, torch)

148 

class Normalize(BaseTransform):
    """Per-channel 2x2 whitening of the [Re, Im] components of complex data.

    Each channel c is centred with ``means[c]`` (its real/imaginary means) and then multiplied
    by the whitening matrix W_c = Sigma_c^{-1/2}, precomputed from ``covs[c]``.
    Both numpy arrays and PyTorch tensors are supported; input must be complex CHW.

    Args:
        means (array-like): Per-channel means for centering, shape (C, 2).
        covs (array-like): Per-channel 2x2 covariance matrices, shape (C, 2, 2). They are
            symmetrised and their eigenvalues are floored at ``eps`` before inversion, so
            they are expected to be symmetric positive definite.
        eps (float, optional): Eigenvalue floor used in the inverse square root.
            Default 1e-12.

    Returns:
        np.ndarray | torch.Tensor: Normalized complex data with the same shape as the input.

    Raises:
        AssertionError: If ``means`` / ``covs`` do not have shapes (C, 2) / (C, 2, 2).
        ValueError: If the input is not a complex 3-D array/tensor or its channel count
            differs from C.

    Example:
        >>> means = [[0, 0], [1, 1]]
        >>> covs = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]
        >>> transform = Normalize(means, covs)
        >>> output = transform(input_data)  # Normalizes each channel independently
    """

    def __init__(self, means, covs, eps=1e-12):
        self.means = np.asarray(means, dtype=np.float64)  # (C, 2)
        self.covs = np.asarray(covs, dtype=np.float64)  # (C, 2, 2)
        assert self.means.ndim == 2 and self.means.shape[1] == 2
        assert self.covs.ndim == 3 and self.covs.shape[1:] == (2, 2)
        self.num_channels = self.means.shape[0]

        # Precompute W_c = V diag(1 / sqrt(lambda)) V^T for every channel, using the
        # eigendecomposition of the symmetrised covariance with eigenvalues floored at eps.
        symmetric = 0.5 * (self.covs + np.swapaxes(self.covs, -1, -2))
        eigvals, eigvecs = np.linalg.eigh(symmetric)
        eigvals = np.maximum(eigvals, eps)
        inv_sqrt = np.zeros_like(self.covs)
        diag = np.arange(2)
        inv_sqrt[:, diag, diag] = 1.0 / np.sqrt(eigvals)
        self.W_np = eigvecs @ inv_sqrt @ np.transpose(eigvecs, (0, 2, 1))
        super().__init__()

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        if not (x.ndim == 3 and np.iscomplexobj(x)):
            raise ValueError("Expect complex CHW ndarray.")
        if x.shape[0] != self.num_channels:
            raise ValueError("Channel mismatch.")

        stacked = np.stack((x.real, x.imag), axis=1)  # (C, 2, H, W)
        centred = stacked - self.means[..., np.newaxis, np.newaxis]
        # Apply W_c to the 2-vector [Re, Im] at every pixel of channel c.
        whitened = np.einsum("cij,cjhw->cihw", self.W_np, centred)
        return whitened[:, 0] + 1j * whitened[:, 1]

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        if not (x.dim() == 3 and x.is_complex()):
            raise ValueError("Expect complex CHW tensor.")
        n_channels = x.shape[0]
        if n_channels != self.num_channels:
            raise ValueError("Channel mismatch.")

        # Bring the precomputed float64 parameters to the input's device and real precision.
        real_dtype = x.real.dtype
        whitening = torch.as_tensor(self.W_np, device=x.device, dtype=real_dtype)
        mean = torch.as_tensor(self.means, device=x.device, dtype=real_dtype)

        stacked = torch.stack([x.real, x.imag], dim=1)  # (C, 2, H, W)
        centred = stacked - mean.view(n_channels, 2, 1, 1)
        whitened = torch.einsum("cij,cjhw->cihw", whitening, centred)
        return torch.complex(whitened[:, 0], whitened[:, 1])

220 

class Amplitude(BaseTransform):
    """Map complex-valued data to its amplitude (absolute value).

    Args:
        dtype (str): Output data type, one of the dtypes accepted by BaseTransform
            (e.g. 'float32', 'float64').

    Returns:
        np.ndarray | torch.Tensor: Element-wise magnitudes with the same shape as the
        input, cast to ``dtype``.
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        magnitude = x.abs()
        return magnitude.to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        magnitude = np.absolute(x)
        return magnitude.astype(self.np_dtype)

243 

244 

class RealImaginary(BaseTransform):
    """Split complex data into real and imaginary channels.

    The real parts of all C input channels come first, followed by the C imaginary
    parts, concatenated along the channel (first) axis.

    Returns:
        np.ndarray | torch.Tensor: Real-valued data of shape (2*C, H, W) for a (C, H, W)
        input.
    """

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # [Re_0..Re_{C-1}, Im_0..Im_{C-1}] along dim 0: (C, H, W) -> (2C, H, W)
        return torch.cat([x.real, x.imag], dim=0)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        # Same channel ordering as the torch path.
        return np.concatenate([x.real, x.imag], axis=0)

266 

267 

class RandomPhase(BaseTransform):
    """Randomly phase-shift complex-valued input data.

    Multiplies the input element-wise by exp(j * phi), where phi is drawn independently for
    every element from a uniform distribution on [0, 2*pi), or on [-pi, pi) when
    ``centering`` is enabled.

    Args:
        dtype (str): Output data type. Must be one of the dtypes accepted by BaseTransform;
            a complex dtype ('complex64' or 'complex128') is expected since the result is
            complex.
        centering (bool, optional): If True, subtract pi from the sampled phases so that
            they are centred on 0. Default is False.

    Returns:
        np.ndarray | torch.Tensor: Phase-shifted data with the same shape as the input,
        cast to ``dtype``.

    Example:
        >>> transform = RandomPhase(dtype='complex64')
        >>> x = torch.ones(1, 3, 3, dtype=torch.complex64)
        >>> output = transform(x)  # random phase per element, magnitudes unchanged

    Notes:
        - Input data is expected to be complex-valued.
        - Phases are always sampled as *real* numbers, so the transform only rotates the
          data in the complex plane and never changes its magnitude.
    """

    def __init__(self, dtype: str, centering: bool = False) -> None:
        super().__init__(dtype)
        self.centering = centering

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        # torch.rand_like(x) on a complex tensor draws *complex* samples; exp(1j * phase)
        # would then also scale the magnitude by exp(-2*pi*Im(sample)). Sample real phases
        # in the input's real precision instead.
        real_dtype = x.real.dtype if x.is_complex() else x.dtype
        phase = torch.rand(x.shape, dtype=real_dtype, device=x.device) * 2 * torch.pi
        if self.centering:
            phase = phase - torch.pi
        return (x * torch.exp(1j * phase)).to(self.torch_dtype)

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        # np.random.rand always returns real float64 samples.
        phase = np.random.rand(*x.shape) * 2 * np.pi
        if self.centering:
            phase = phase - np.pi
        return (x * np.exp(1j * phase)).astype(self.np_dtype)

311 

312 

313class ToReal: 

314 """Extracts the real part of a complex-valued input tensor. 

315 

316 The `ToReal` transform takes either a numpy array or a PyTorch tensor containing complex numbers 

317 and returns only their real parts. If the input is already real-valued, it remains unchanged. 

318 

319 Returns: 

320 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only 

321 the real components of each element. 

322 

323 Example: 

324 >>> to_real = ToReal() 

325 >>> output = to_real(complex_tensor) 

326 """ 

327 

328 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

329 return x.real 

330 

331 

332class ToImaginary: 

333 """Extracts the imaginary part of a complex-valued input tensor. 

334 

335 The `ToImaginary` transform takes either a numpy array or a PyTorch tensor containing complex numbers 

336 and returns only their imaginary parts. If the input is already real-valued, it remains unchanged. 

337 

338 Returns: 

339 np.ndarray | torch.Tensor: A tensor with the same shape as the input but containing only 

340 the imaginary components of each element. 

341 

342 Example: 

343 >>> to_imaginary = ToImaginary() 

344 >>> output = to_imaginary(complex_tensor) 

345 """ 

346 

347 def __call__(self, x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

348 return x.imag 

349 

350 

class FFT2(BaseTransform):
    """Apply a 2D Fast Fourier Transform (FFT) to the input.

    Delegates to ``F.applyfft2_np`` (numpy) or ``F.applyfft2_torch`` (torch) over the
    configured axes. As documented by the original author, the zero-frequency component
    is shifted to the center.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The 2D Fourier transform of the input, with the same
        shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        axes = self.axis
        return F.applyfft2_np(x, axis=axes)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        axes = self.axis
        return F.applyfft2_torch(x, dim=axes)

377 

378 

class IFFT2(BaseTransform):
    """Apply a 2D inverse Fast Fourier Transform (IFFT) to the input.

    Delegates to ``F.applyifft2_np`` (numpy) or ``F.applyifft2_torch`` (torch) over the
    configured axes, which default to the last two dimensions. As documented by the original
    author, an inverse FFT shift is applied before the inverse transform.

    Args:
        axis (Tuple[int, ...], optional): Axes over which the inverse FFT is computed.
            Default is (-2, -1).

    Returns:
        numpy.ndarray or torch.Tensor: The inverse Fourier transform of the input, with the
        same shape as the input.
    """

    def __init__(self, axis: Tuple[int, ...] = (-2, -1)):
        self.axis = axis

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return F.applyifft2_np(x, axis=self.axis)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return F.applyifft2_torch(x, dim=self.axis)

405 

406 

class PadIfNeeded(BaseTransform):
    """Pad an image whose spatial size is below a required minimum.

    Delegates to ``F.padifneeded``. Per the original documentation, padding is added
    symmetrically where possible. When the required padding along a dimension is odd, the
    right/bottom side receives one more element than the left/top side.

    Args:
        min_height (int): Minimum height of the output image.
        min_width (int): Minimum width of the output image.
        border_mode (str): Padding mode ('constant', 'reflect', ...). Default 'constant'.
        pad_value (float): Fill value used for constant padding. Default 0.

    Returns:
        np.ndarray | torch.Tensor: Image at least min_height x min_width, or the original
        image when no padding is required.

    Example:
        >>> transform = PadIfNeeded(min_height=256, min_width=256)
        >>> padded_image = transform(small_image)  # pads if smaller than 256x256
    """

    def __init__(
        self,
        min_height: int,
        min_width: int,
        border_mode: str = "constant",
        pad_value: float = 0,
    ) -> None:
        self.min_height, self.min_width = min_height, min_width
        self.border_mode = border_mode
        self.pad_value = pad_value

    def _pad(self, x):
        # The functional helper handles both backends with the same positional arguments.
        return F.padifneeded(
            x, self.min_height, self.min_width, self.border_mode, self.pad_value
        )

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._pad(x)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._pad(x)

452 

453 

class CenterCrop(BaseTransform):
    """Crop a centred height x width window out of a CHW array/tensor.

    The same crop is applied to every channel. Both backends delegate to
    ``F.center_crop``.

    Args:
        height (int): Target height of the cropped output.
        width (int): Target width of the cropped output.

    Returns:
        np.ndarray | torch.Tensor: Cropped data of shape (C, height, width), where C is the
        number of channels.

    Examples:
        >>> transform = CenterCrop(height=224, width=224)
        >>> output = transform(input_tensor)  # Center crops to 224x224

    Notes:
        - How inputs smaller than the crop size are handled is decided by
          ``F.center_crop``; the original documentation states the input is returned
          unchanged in that case.
    """

    def __init__(self, height: int, width: int) -> None:
        self.height, self.width = height, width

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        cropped = F.center_crop(x, self.height, self.width)
        return cropped

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        cropped = F.center_crop(x, self.height, self.width)
        return cropped

487 

488 

class FFTResize(BaseTransform):
    """Resize an image in the spectral domain using Fourier transforms.

    The CHW input is transformed with ``F.applyfft2_np`` / ``F.applyfft2_torch`` along
    ``axis``. The spectrum is then padded (``F.padifneeded``) and centre-cropped
    (``F.center_crop``) to the target (height, width). Finally the corresponding inverse FFT
    helper brings it back to the spatial domain. Optionally the result is multiplied by
    (height * width) / (H_in * W_in), which is intended to keep output amplitudes consistent
    with the change in the number of samples.

    Args:
        size (Tuple[int, int]): Target dimensions (height, width).
        axis (Tuple[int, ...], optional): Axes over which the FFT is applied.
            Default is (-2, -1), i.e. the Height and Width axes of a CHW input.
        scale (bool, optional): If True, rescale the output by
            target_size / original_size. Default is False.
        dtype (str, optional): Output dtype name, 'complex64' or 'complex128'. ``None`` or
            any non-complex name falls back to 'complex64'. Default is 'complex64'.

    Returns:
        numpy.ndarray or torch.Tensor: Resized complex-valued data of shape
        (C, height, width), always cast to ``dtype`` (with or without scaling).

    Raises:
        AssertionError: If ``size`` or ``axis`` is not a tuple.

    Examples:
        >>> transform = FFTResize((128, 128))
        >>> resized_image = transform(input_tensor)  # Resize to 128x128 using FFT

    Notes:
        - Input must be a 3-D array/tensor of shape Channel x Height x Width; the original
          size is computed as ``x.shape[1] * x.shape[2]``.
        - Operates on complex-valued data, preserving phase information.
        - The output is complex-valued because of the FFT round trip. For real-valued data,
          apply ToReal after this transform.
    """

    def __init__(
        self,
        size: Tuple[int, ...],
        axis: Tuple[int, ...] = (-2, -1),
        scale: bool = False,
        dtype: Optional[str] = "complex64",
    ) -> None:
        # The FFT round trip always yields complex data, so only complex dtypes make sense.
        if dtype is None or "complex" not in str(dtype):
            dtype = "complex64"

        super().__init__(dtype)
        assert isinstance(size, tuple), "size must be a tuple"
        assert isinstance(axis, tuple), "axis must be a tuple"
        self.height = size[0]
        self.width = size[1]
        self.axis = axis
        self.scale = scale

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_np(x, axis=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_np(x, axis=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast in both branches; the scaled result previously skipped the dtype conversion.
        return x.astype(self.np_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        original_size = x.shape[1] * x.shape[2]
        target_size = self.height * self.width

        x = F.applyfft2_torch(x, dim=self.axis)
        x = F.padifneeded(x, self.height, self.width)
        x = F.center_crop(x, self.height, self.width)
        x = F.applyifft2_torch(x, dim=self.axis)

        if self.scale:
            x = x * target_size / original_size
        # Cast in both branches; the scaled result previously skipped the dtype conversion.
        return x.to(self.torch_dtype)

572 

573 

class SpatialResize:
    """Resize complex data in image space by resizing real and imaginary parts separately.

    Each 2-D plane of the real and imaginary parts is converted to a PIL image and resized
    with ``Image.resize`` using PIL's default resampling filter (bicubic for float images in
    recent Pillow versions). The resized parts are then recombined as real + 1j * imag.

    Arguments:
        size: Target size as (height, width). PIL expects (width, height), so the order is
            swapped internally.

    Returns:
        np.ndarray | torch.Tensor: The resized complex data. 2-D inputs give (height, width);
        3-D inputs are treated as (C, H, W), are resized channel by channel, and give
        (C, height, width). Torch inputs produce a torch tensor on the input's original
        device.

    Note:
        Torch inputs are detached and copied to the CPU before resizing (PIL needs a numpy
        array), so no gradient flows through this transform.
    """

    def __init__(self, size):
        self.size = size

    def __call__(
        self, array: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        device = None
        if isinstance(array, torch.Tensor):
            device = array.device
            # `.numpy()` only works on CPU tensors that do not require grad.
            array = array.detach().cpu().numpy()

        real_part = array.real
        imaginary_part = array.imag

        def zoom(plane):
            """Resize a single 2-D real-valued plane with PIL."""
            image = Image.fromarray(plane)
            image = image.resize((self.size[1], self.size[0]))
            return np.array(image)

        if len(array.shape) == 2:
            # A single two-dimensional plane.
            resized_real = zoom(real_part)
            resized_imaginary = zoom(imaginary_part)
        else:
            # Three dimensions: assume the first one is the channel axis and resize
            # every channel independently.
            resized_real = np.stack([zoom(plane) for plane in real_part])
            resized_imaginary = np.stack([zoom(plane) for plane in imaginary_part])

        resized_array = resized_real + 1j * resized_imaginary

        # Convert back to a torch tensor on the caller's device if needed.
        if device is not None:
            resized_array = torch.as_tensor(resized_array, device=device)

        return resized_array

632 

633 

class PolSAR(BaseTransform):
    """Convert Polarimetric Synthetic Aperture Radar (PolSAR) data between channel layouts.

    Supports 1, 2, 3 and 4 output channels for numpy arrays and PyTorch tensors. A dictionary
    input such as {'HH': data1, 'VV': data2} is first turned into an array by
    ``F.polsar_dict_to_array``. The original documentation states that its values are stacked
    along axis 0 into a CHW array.

    Args:
        out_channel (int): Desired number of output channels (1, 2, 3 or 4).

    Supported conversions (input channels -> output channels):
        - 1 -> 1: identity
        - 2 -> 1: first channel; 2 -> 2: identity
        - 3 -> 1: first channel; 3 -> 2: [HH, VV] (channels 0 and 2);
          3 -> 3: all channels [HH, HV, VV]
        - 4 -> 1: first channel; 4 -> 2: [HH, VV] (channels 0 and 3);
          4 -> 3: [HH, (HV + VH) / 2, VV]; 4 -> 4: identity [HH, HV, VH, VV]

    Raises:
        ValueError: If the requested conversion is not in the table above.

    Example:
        >>> transform = PolSAR(out_channel=3)
        >>> # For 4-channel input [HH, HV, VH, VV]
        >>> output = transform(input_data)  # Returns [HH, (HV+VH)/2, VV]

    Note:
        - Input data should have format Channels x Height x Width (CHW).
        - With out_channel == 1 the first (HH) channel is returned.
    """

    def __init__(self, out_channel: int) -> None:
        self.out_channel = out_channel

    def _handle_single_channel(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        if out_channels == 1:
            return x
        return None

    def _handle_two_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int
    ) -> np.ndarray | torch.Tensor:
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return x
        return None

    def _handle_three_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return backend.stack((x[0], x[2]))
        if out_channels == 3:
            return backend.stack((x[0], x[1], x[2]))
        return None

    def _handle_four_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        if out_channels == 1:
            return x[0:1]
        if out_channels == 2:
            return backend.stack((x[0], x[3]))
        if out_channels == 3:
            # Average the two cross-polarised channels.
            return backend.stack((x[0], 0.5 * (x[1] + x[2]), x[3]))
        if out_channels == 4:
            return x
        return None

    def _convert_channels(
        self, x: np.ndarray | torch.Tensor, out_channels: int, backend: ModuleType
    ) -> np.ndarray | torch.Tensor:
        n_in = x.shape[0]
        if n_in == 1:
            result = self._handle_single_channel(x, out_channels)
        elif n_in == 2:
            result = self._handle_two_channels(x, out_channels)
        elif n_in == 3:
            result = self._handle_three_channels(x, out_channels, backend)
        elif n_in == 4:
            result = self._handle_four_channels(x, out_channels, backend)
        else:
            result = None

        if result is None:
            raise ValueError(
                f"Invalid conversion: {x.shape[0]} -> {out_channels} channels"
            )
        return result

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return self._convert_channels(x, self.out_channel, np)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return self._convert_channels(x, self.out_channel, torch)

    def __call__(
        self, x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]
    ) -> np.ndarray | torch.Tensor:
        # Dictionaries of polarisations are turned into an array before normal dispatch.
        array = F.polsar_dict_to_array(x)
        return super().__call__(array)

735 

736 

class Unsqueeze(BaseTransform):
    """Insert a new axis of length one into the input.

    Args:
        dim (int): Position at which the singleton axis is inserted.

    Returns:
        np.ndarray | torch.Tensor: The input with one extra dimension of size 1 at
        position ``dim``.

    Example:
        >>> transform = Unsqueeze(dim=0)
        >>> x = torch.randn(3, 4)  # Shape (3, 4)
        >>> y = transform(x)  # Shape (1, 3, 4)
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call_numpy__(self, x: np.ndarray) -> np.ndarray:
        return np.expand_dims(x, self.dim)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.unsqueeze(x, self.dim)

764 

765 

class ToTensor(BaseTransform):
    """Convert a numpy array or torch tensor to a torch tensor of a given dtype.

    Numpy arrays are converted with ``torch.as_tensor``; existing tensors are cast with
    ``Tensor.to``.

    Args:
        dtype (str): Target dtype name. BaseTransform only accepts 'float32', 'float64',
            'complex64' or 'complex128'; any other value raises an AssertionError.

    Returns:
        torch.Tensor: The converted tensor with the requested dtype.

    Example:
        >>> transform = ToTensor(dtype='float32')
        >>> x_numpy = np.ones((1, 2, 2))
        >>> x_tensor = transform(x_numpy)  # torch.float32 tensor
        >>> x_existing = torch.ones((1, 2, 2), dtype=torch.float64)
        >>> x_converted = transform(x_existing)  # torch.float32 tensor
    """

    def __init__(self, dtype: str) -> None:
        super().__init__(dtype)

    def __call_numpy__(self, x: np.ndarray) -> torch.Tensor:
        # Note: the numpy branch returns a torch tensor (this is the conversion itself).
        return torch.as_tensor(x, dtype=self.torch_dtype)

    def __call_torch__(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self.torch_dtype)