Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/transforms/functional.py: 72%

57 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2025 Huy Nguyen 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Tuple, Dict 

25from types import ModuleType 

26 

27# External imports 

28import torch 

29import numpy as np 

30 

31 

32def polsar_dict_to_array(x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]) -> np.ndarray | torch.Tensor: 

33 """ 

34 Convert a dictionary of numpy arrays to a stacked array. 

35 

36 Args: 

37 x (np.ndarray | torch.Tensor | Dict[str, np.ndarray]): The input data.  

38 It can be a single numpy array or PyTorch tensor, or a dictionary where keys are 

39 one of 'HH', 'HV', 'VH', 'VV' and values are arrays. 

40 

41 Returns: 

42 np.ndarray | torch.Tensor: A stacked array from the dictionary's values if input is a dictionary, 

43 otherwise returns the input unchanged. 

44  

45 Raises: 

46 AssertionError: If any key in the dictionary is not one of 'HH', 'HV', 'VH', 'VV'. 

47 """ 

48 if isinstance(x, Dict): 

49 for k,v in x.items(): 

50 assert k in ['HH', 'HV', 'VH', 'VV'], f'Invalid key {k} in input' 

51 assert isinstance(v, np.ndarray), "Values must be numpy arrays" 

52 return np.stack([ 

53 v for v in x.values() 

54 ]) 

55 return x 

56 

57 

58def check_input(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: 

59 """Ensure image is in CHW format for 2D-tensors, convert if necessary. 

60  

61 Args: 

62 x (np.ndarray or torch.Tensor): Input image to check/convert format 

63  

64 Returns: 

65 np.ndarray or torch.Tensor: Image in CHW format 

66  

67 Raises: 

68 TypeError: If input is not numpy array or torch tensor 

69  

70 Example: 

71 >>> img = np.zeros((64, 64)) # HWC format 

72 >>> chw_img = check_input(img) # Converts to (1, 64, 64) 

73 """ 

74 if not isinstance(x, (np.ndarray, torch.Tensor)): 

75 raise TypeError("Element should be a numpy array or a tensor") 

76 if len(x.shape) == 2: 

77 return x[np.newaxis, :, :] 

78 return x 

79 

80 

81def log_normalize_amplitude( 

82 x: np.ndarray | torch.Tensor, 

83 backend: ModuleType, 

84 keep_phase: bool, 

85 min_value: float, 

86 max_value: float, 

87) -> np.ndarray | torch.Tensor: 

88 """ 

89 Normalize the amplitude of a complex signal with logarithmic scaling. 

90  

91 Args: 

92 x: Input array or tensor containing complex numbers. The type can be either numpy ndarray or PyTorch Tensor. 

93 backend: Module providing mathematical functions, allowing compatibility with numpy or PyTorch. 

94 keep_phase: Boolean indicating whether to retain the original phase of the input signal. 

95 max_value: Maximum amplitude value for normalization. 

96 min_value: Minimum amplitude value for normalization. 

97 

98 Returns: 

99 A numpy ndarray or torch.Tensor containing the log-normalized amplitude, optionally with the original phase. 

100 """ 

101 assert backend.__name__ in ["numpy", "torch"], "Backend must be numpy or torch" 

102 amplitude = backend.abs(x) 

103 phase = backend.angle(x) 

104 amplitude = backend.clip(amplitude, min_value, max_value) 

105 transformed_amplitude = ( 

106 backend.log10(amplitude / min_value) 

107 ) / (np.log10(max_value / min_value)) 

108 if keep_phase: 

109 return transformed_amplitude * backend.exp(1j * phase) 

110 else: 

111 return transformed_amplitude 

112 

113 

def applyfft2_np(x: np.ndarray, axis: Tuple[int, ...]) -> np.ndarray:
    """Compute a 2D FFT and shift the zero-frequency component to the centre.

    Args:
        x (np.ndarray): Input array to transform
        axis (Tuple[int, ...]): Axes over which both the FFT and the shift are applied

    Returns:
        np.ndarray: The shifted Fourier transform of ``x``
    """
    spectrum = np.fft.fft2(x, axes=axis)
    centred = np.fft.fftshift(spectrum, axes=axis)
    return centred

125 

126 

def applyifft2_np(x: np.ndarray, axis: Tuple[int, ...]) -> np.ndarray:
    """Undo the frequency shift and compute a 2D inverse FFT.

    Args:
        x (np.ndarray): Centred spectrum to transform back
        axis (Tuple[int, ...]): Axes over which both the unshift and the IFFT are applied

    Returns:
        np.ndarray: The inverse Fourier transform of the unshifted input
    """
    unshifted = np.fft.ifftshift(x, axes=axis)
    restored = np.fft.ifft2(unshifted, axes=axis)
    return restored

138 

139 

def applyfft2_torch(x: torch.Tensor, dim: Tuple[int, ...]) -> torch.Tensor:
    """Compute a 2D FFT and shift the zero-frequency component to the centre.

    Args:
        x (torch.Tensor): Input tensor to transform
        dim (Tuple[int, ...]): Dimensions over which both the FFT and the shift are applied

    Returns:
        torch.Tensor: The shifted Fourier transform of ``x``
    """
    spectrum = torch.fft.fft2(x, dim=dim)
    centred = torch.fft.fftshift(spectrum, dim=dim)
    return centred

151 

152 

def applyifft2_torch(x: torch.Tensor, dim: Tuple[int, ...]) -> torch.Tensor:
    """Undo the frequency shift and compute a 2D inverse FFT.

    Args:
        x (torch.Tensor): Centred spectrum to transform back
        dim (Tuple[int, ...]): Dimensions over which both the unshift and the IFFT are applied

    Returns:
        torch.Tensor: The inverse Fourier transform of the unshifted input
    """
    unshifted = torch.fft.ifftshift(x, dim=dim)
    restored = torch.fft.ifft2(unshifted, dim=dim)
    return restored

164 

165 

def get_padding(current_size: int, target_size: int) -> Tuple[int, ...]:
    """Split the size difference along one axis into (before, after) padding.

    The total padding is ``target_size - current_size``. When it cannot be
    split evenly, the extra unit goes before the content for even target
    sizes and after the content for odd target sizes.

    Args:
        current_size (int): Current dimension size
        target_size (int): Desired dimension size after padding

    Returns:
        Tuple[int, ...]: Padding values for (before, after) positions. Both are
            zero or negative when ``current_size`` is at least ``target_size``.

    Example:
        >>> get_padding(5, 7)  # Pad 5->7
        (1, 1)
        >>> get_padding(3, 6)  # Pad 3->6 (even target)
        (2, 1)
    """
    deficit = target_size - current_size
    # Even targets round the "before" share up, odd targets round it down.
    bias = 0 if target_size % 2 else 1
    before = (deficit + bias) // 2
    return before, deficit - before

198 

199 

def padifneeded(
    x: np.ndarray | torch.Tensor,
    min_height: int,
    min_width: int,
    border_mode: str = "constant",
    pad_value: float = 0
) -> np.ndarray | torch.Tensor:
    """Pad image if smaller than desired size.

    Pads the height and/or width of an image up to the requested minimum.
    The padding is split between both sides by ``get_padding``. A dimension
    that is already large enough is left untouched.

    Args:
        x (Union[np.ndarray, torch.Tensor]): Input image tensor/array with shape (C,H,W)
        min_height (int): Minimum required height after padding
        min_width (int): Minimum required width after padding
        border_mode (str): Padding mode, forwarded unchanged to ``np.pad`` for
            arrays or ``torch.nn.functional.pad`` for tensors. It must be a mode
            name the selected backend accepts (e.g. 'constant', 'reflect';
            'edge' is numpy-only, 'replicate' is torch-only).
        pad_value (float): Fill value, used only when border_mode is 'constant'. Default: 0

    Returns:
        Union[np.ndarray, torch.Tensor]: Padded image if dimensions were smaller than
            minimum required, otherwise returns original image unchanged

    Example:
        >>> img = torch.randn(3, 50, 60)  # RGB image 50x60
        >>> padded = padifneeded(img, 64, 64, 'constant')  # Pads to 64x64
        >>> padded.shape
        torch.Size([3, 64, 64])
    """
    _, h, w = x.shape
    # Calculate padding sizes
    top_pad, bottom_pad = get_padding(h, min_height)
    left_pad, right_pad = get_padding(w, min_width)
    padding = [top_pad, bottom_pad, left_pad, right_pad]
    # Return original if no padding needed
    if all(p <= 0 for p in padding):
        return x

    # A dimension that is already large enough yields negative values; never crop.
    top_pad, bottom_pad, left_pad, right_pad = (max(0, p) for p in padding)
    is_constant = border_mode == "constant"

    if isinstance(x, np.ndarray):
        pad_width = ((0, 0), (top_pad, bottom_pad), (left_pad, right_pad))
        # np.pad rejects ``constant_values`` for every mode except 'constant',
        # so the fill value is only forwarded in that mode.
        if is_constant:
            return np.pad(x, pad_width, mode=border_mode, constant_values=pad_value)
        return np.pad(x, pad_width, mode=border_mode)

    # torch expects the padding of the last dimension first: (left, right, top, bottom).
    torch_pad = (left_pad, right_pad, top_pad, bottom_pad)
    # Non-constant torch modes do not take a fill value either.
    if is_constant:
        return torch.nn.functional.pad(x, torch_pad, mode=border_mode, value=pad_value)
    return torch.nn.functional.pad(x, torch_pad, mode=border_mode)

257 

258 

259def center_crop(x: np.ndarray | torch.Tensor, height: int, width: int) -> np.ndarray | torch.Tensor: 

260 """ 

261 Center crops an image to the specified dimensions. 

262 

263 This function takes an image and crops it to the specified height and width,  

264 centered around the middle of the image. If the requested dimensions are larger  

265 than the image, it will use the maximum possible size. 

266 

267 Args: 

268 x (Union[np.ndarray, torch.Tensor]): Input image tensor/array with shape (C, H, W) 

269 height (int): Desired height of the cropped image 

270 width (int): Desired width of the cropped image 

271 

272 Returns: 

273 Union[np.ndarray, torch.Tensor]: Center cropped image with shape (C, height, width) 

274 

275 Example: 

276 >>> img = torch.randn(3, 100, 100) # RGB image 100x100 

277 >>> cropped = center_crop(img, 60, 60) # Returns center 60x60 crop 

278 >>> cropped.shape 

279 torch.Size([3, 60, 60]) 

280 """ 

281 l_h = max(0, x.shape[0] // 2 - height // 2) 

282 l_w = max(0, x.shape[0] // 2 - width // 2) 

283 r_h = l_h + height 

284 r_w = l_w + width 

285 return x[:, l_h:r_h, l_w:r_w]