# MIT License

# Copyright (c) 2025 Huy Nguyen

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Tuple, Dict
from types import ModuleType

# External imports
import torch
import numpy as np


def polsar_dict_to_array(x: np.ndarray | torch.Tensor | Dict[str, np.ndarray]) -> np.ndarray | torch.Tensor:
    """
    Convert a dictionary of polarimetric channels to a stacked array.

    Args:
        x (np.ndarray | torch.Tensor | Dict[str, np.ndarray]): The input data.
            It can be a single numpy array or PyTorch tensor, or a dictionary whose
            keys are among 'HH', 'HV', 'VH', 'VV' and whose values are numpy arrays.

    Returns:
        np.ndarray | torch.Tensor: An array stacked from the dictionary's values
        if the input is a dictionary, otherwise the input unchanged.

    Raises:
        AssertionError: If any key in the dictionary is not one of 'HH', 'HV',
        'VH', 'VV', or if any value is not a numpy array.
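
    Example:
        >>> # A minimal sketch with hypothetical 2x2 channels; channels are
        >>> # stacked in the dictionary's insertion order (here HH, then VV).
        >>> x = {"HH": np.zeros((2, 2)), "VV": np.ones((2, 2))}
        >>> polsar_dict_to_array(x).shape
        (2, 2, 2)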
47 """
48 if isinstance(x, Dict):
49 for k,v in x.items():
50 assert k in ['HH', 'HV', 'VH', 'VV'], f'Invalid key {k} in input'
51 assert isinstance(v, np.ndarray), "Values must be numpy arrays"
52 return np.stack([
53 v for v in x.values()
54 ])
55 return x


def check_input(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """Ensure an image is in (C, H, W) format, adding a channel axis to 2D inputs.

    Args:
        x (np.ndarray or torch.Tensor): Input image to check/convert

    Returns:
        np.ndarray or torch.Tensor: Image in (C, H, W) format

    Raises:
        TypeError: If the input is neither a numpy array nor a torch tensor

    Example:
        >>> img = np.zeros((64, 64))  # 2D (H, W) input
        >>> chw_img = check_input(img)  # Converted to (1, 64, 64)
    """
    if not isinstance(x, (np.ndarray, torch.Tensor)):
        raise TypeError("Element should be a numpy array or a tensor")
    if len(x.shape) == 2:
        # np.newaxis is None, so this also adds a leading axis to torch tensors
        return x[np.newaxis, :, :]
    return x


def log_normalize_amplitude(
    x: np.ndarray | torch.Tensor,
    backend: ModuleType,
    keep_phase: bool,
    min_value: float,
    max_value: float,
) -> np.ndarray | torch.Tensor:
    """
    Normalize the amplitude of a complex signal with logarithmic scaling.

    The amplitude is clipped to [min_value, max_value] and mapped to [0, 1]
    on a logarithmic scale.

    Args:
        x: Input array or tensor of complex numbers, either a numpy ndarray or a PyTorch tensor.
        backend: Module providing the mathematical functions (numpy or torch).
        keep_phase: Whether to retain the original phase of the input signal.
        min_value: Minimum amplitude value for normalization.
        max_value: Maximum amplitude value for normalization.

    Returns:
        A numpy ndarray or torch.Tensor containing the log-normalized amplitude,
        optionally with the original phase.
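
    Example:
        >>> # A minimal sketch with illustrative values spanning the clip range.
        >>> x = np.array([1 + 0j, 10 + 0j, 100 + 0j])
        >>> log_normalize_amplitude(x, np, keep_phase=False, min_value=1.0, max_value=100.0)
        array([0. , 0.5, 1. ])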
100 """
101 assert backend.__name__ in ["numpy", "torch"], "Backend must be numpy or torch"
102 amplitude = backend.abs(x)
103 phase = backend.angle(x)
104 amplitude = backend.clip(amplitude, min_value, max_value)
105 transformed_amplitude = (
106 backend.log10(amplitude / min_value)
107 ) / (np.log10(max_value / min_value))
108 if keep_phase:
109 return transformed_amplitude * backend.exp(1j * phase)
110 else:
111 return transformed_amplitude


def applyfft2_np(x: np.ndarray, axis: Tuple[int, ...]) -> np.ndarray:
    """Apply a centered 2D Fast Fourier Transform to an image.

    Args:
        x (np.ndarray): Input array to apply the FFT to
        axis (Tuple[int, ...]): Axes over which to compute the FFT

    Returns:
        np.ndarray: The Fourier-transformed array, with the zero-frequency
        component shifted to the center of the chosen axes
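
    Example:
        >>> # A minimal sketch: a constant image concentrates all energy in the
        >>> # DC bin, which fftshift moves to the center of the spectrum.
        >>> X = applyfft2_np(np.ones((4, 4)), axis=(-2, -1))
        >>> complex(X[2, 2])
        (16+0j)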
123 """
124 return np.fft.fftshift(np.fft.fft2(x, axes=axis), axes=axis)


def applyifft2_np(x: np.ndarray, axis: Tuple[int, ...]) -> np.ndarray:
    """Apply a 2D inverse Fast Fourier Transform to a centered spectrum.

    Args:
        x (np.ndarray): Input array to apply the IFFT to
        axis (Tuple[int, ...]): Axes over which to compute the IFFT

    Returns:
        np.ndarray: The inverse Fourier-transformed array; the centering shift
        applied by applyfft2_np is undone first
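
    Example:
        >>> # A minimal round-trip sketch: the IFFT recovers the FFT's input.
        >>> rng = np.random.default_rng(0)
        >>> x = rng.standard_normal((1, 8, 8)) + 1j * rng.standard_normal((1, 8, 8))
        >>> np.allclose(applyifft2_np(applyfft2_np(x, (-2, -1)), (-2, -1)), x)
        True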
136 """
137 return np.fft.ifft2(np.fft.ifftshift(x, axes=axis), axes=axis)


def applyfft2_torch(x: torch.Tensor, dim: Tuple[int, ...]) -> torch.Tensor:
    """Apply a centered 2D Fast Fourier Transform to an image.

    Args:
        x (torch.Tensor): Input tensor to apply the FFT to
        dim (Tuple[int, ...]): Dimensions over which to compute the FFT

    Returns:
        torch.Tensor: The Fourier-transformed tensor, with the zero-frequency
        component shifted to the center of the chosen dimensions
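
    Example:
        >>> # A minimal sketch mirroring the numpy variant: the DC component of
        >>> # a constant image lands at the center after fftshift.
        >>> X = applyfft2_torch(torch.ones(4, 4), dim=(-2, -1))
        >>> complex(X[2, 2])
        (16+0j)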
149 """
150 return torch.fft.fftshift(torch.fft.fft2(x, dim=dim), dim=dim)


def applyifft2_torch(x: torch.Tensor, dim: Tuple[int, ...]) -> torch.Tensor:
    """Apply a 2D inverse Fast Fourier Transform to a centered spectrum.

    Args:
        x (torch.Tensor): Input tensor to apply the IFFT to
        dim (Tuple[int, ...]): Dimensions over which to compute the IFFT

    Returns:
        torch.Tensor: The inverse Fourier-transformed tensor; the centering shift
        applied by applyfft2_torch is undone first
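
    Example:
        >>> # A minimal round-trip sketch: the IFFT recovers the FFT's input.
        >>> x = torch.randn(1, 8, 8, dtype=torch.complex128)
        >>> torch.allclose(applyifft2_torch(applyfft2_torch(x, (-2, -1)), (-2, -1)), x)
        True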
162 """
163 return torch.fft.ifft2(torch.fft.ifftshift(x, dim=dim), dim=dim)


def get_padding(current_size: int, target_size: int) -> Tuple[int, int]:
    """Calculate the padding required to reach a target size from the current size.

    Computes padding values for both sides of an axis to reach a target size.
    Handles both even and odd target sizes by adjusting the padding distribution.

    Args:
        current_size (int): Current dimension size
        target_size (int): Desired dimension size after padding

    Returns:
        Tuple[int, int]: Padding values for the (before, after) positions

    Example:
        >>> get_padding(5, 7)  # Pad 5 -> 7: one on each side
        (1, 1)
        >>> get_padding(3, 6)  # Even target: extra padding goes before
        (2, 1)

    Note:
        For even target sizes, the padding is distributed with one extra
        pad value before the content to maintain proper centering.
    """
    # Offset of 1 for even-sized targets, 0 for odd-sized targets
    offset = 1 if target_size % 2 == 0 else 0
    # Total padding needed
    pad_total = target_size - current_size
    # Padding before, accounting for the even-size offset
    pad_before = (pad_total + offset) // 2
    # Padding after is the remainder
    pad_after = pad_total - pad_before
    return pad_before, pad_after


def padifneeded(
    x: np.ndarray | torch.Tensor,
    min_height: int,
    min_width: int,
    border_mode: str = "constant",
    pad_value: float = 0,
) -> np.ndarray | torch.Tensor:
    """Pad an image if it is smaller than the desired size.

    This function pads an image if its dimensions are smaller than the specified
    minimum height and width. The padding is split as evenly as possible between
    the two sides of each axis.

    Args:
        x (Union[np.ndarray, torch.Tensor]): Input image tensor/array with shape (C, H, W)
        min_height (int): Minimum required height after padding
        min_width (int): Minimum required width after padding
        border_mode (str): Padding mode, following np.pad for arrays (e.g. 'constant',
            'edge', 'reflect') and torch.nn.functional.pad for tensors (e.g. 'constant',
            'reflect', 'replicate')
        pad_value (float): Value used for padding when border_mode is 'constant'. Default: 0

    Returns:
        Union[np.ndarray, torch.Tensor]: Padded image if its dimensions were smaller
        than the minimum required, otherwise the original image unchanged

    Example:
        >>> img = torch.randn(3, 50, 60)  # RGB image 50x60
        >>> padded = padifneeded(img, 64, 64, 'constant')  # Pads to 64x64
        >>> padded.shape
        torch.Size([3, 64, 64])
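        >>> # A minimal numpy sketch: a 2x2 array zero-padded up to 4x4
        >>> arr = np.zeros((1, 2, 2))
        >>> padifneeded(arr, 4, 4).shape
        (1, 4, 4)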
228 """
229 _, h, w = x.shape
230 # Calculate padding sizes
231 top_pad, bottom_pad = get_padding(h, min_height)
232 left_pad, right_pad = get_padding(w, min_width)
233 padding = [
234 top_pad, # top
235 bottom_pad, # bottom
236 left_pad, # left
237 right_pad, # right
238 ]
239 # Return original if no padding needed
240 if all(p <= 0 for p in padding):
241 return x
243 padding = [max(0, p) for p in padding]
244 if isinstance(x, np.ndarray):
245 return np.pad(
246 x,
247 ((0, 0), (padding[0], padding[1]), (padding[2], padding[3])),
248 mode=border_mode,
249 constant_values=pad_value
250 )
251 return torch.nn.functional.pad(
252 x,
253 (padding[2], padding[3], padding[0], padding[1]),
254 mode=border_mode,
255 value=pad_value
256 )


def center_crop(x: np.ndarray | torch.Tensor, height: int, width: int) -> np.ndarray | torch.Tensor:
    """
    Center-crop an image to the specified dimensions.

    This function takes an image and crops it to the specified height and width,
    centered around the middle of the image. If the requested dimensions are larger
    than the image, it will use the maximum possible size.

    Args:
        x (Union[np.ndarray, torch.Tensor]): Input image tensor/array with shape (C, H, W)
        height (int): Desired height of the cropped image
        width (int): Desired width of the cropped image

    Returns:
        Union[np.ndarray, torch.Tensor]: Center-cropped image with shape (C, height, width)

    Example:
        >>> img = torch.randn(3, 100, 100)  # RGB image 100x100
        >>> cropped = center_crop(img, 60, 60)  # Returns the center 60x60 crop
        >>> cropped.shape
        torch.Size([3, 60, 60])
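
        >>> # A small worked check on hypothetical data that the crop is centered
        >>> arr = np.arange(16).reshape(1, 4, 4)
        >>> center_crop(arr, 2, 2)[0]
        array([[ 5,  6],
               [ 9, 10]])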
280 """
281 l_h = max(0, x.shape[0] // 2 - height // 2)
282 l_w = max(0, x.shape[0] // 2 - width // 2)
283 r_h = l_h + height
284 r_w = l_w + width
285 return x[:, l_h:r_h, l_w:r_w]