# MIT License

# Copyright (c) 2024 Chengfang Ren, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
import pathlib
from typing import Tuple, Any

# External imports
from torch.utils.data import Dataset
import numpy as np


class Bretigny(Dataset):
    r"""
    Bretigny Dataset

    Arguments:
        root: the root directory containing the npz files for Bretigny
        fold: train (70%), valid (15%), or test (15%)
        transform: the transform applied to the cropped image
        balanced: whether or not to use balanced labels
        patch_size: the dimensions of the patches to consider (rows, cols)
        patch_stride: the shift between two consecutive patches, default: patch_size

    Note:
        An example usage:

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import Bretigny

            dataset = Bretigny(
                rootdir, fold="train", patch_size=(128, 128), transform=lambda x: np.abs(x)
            )
            X, y = dataset[0]

    Displayed below are the train, valid and test parts with the labels overlaid:

    .. image:: ../assets/datasets/bretigny_train.png
       :alt: Train fold
       :width: 60%

    .. image:: ../assets/datasets/bretigny_valid.png
       :alt: Valid fold
       :width: 15%

    .. image:: ../assets/datasets/bretigny_test.png
       :alt: Test fold
       :width: 15%

    """

    """
    Class names
    """

    classes = ["0 - Unlabeled", "1 - Forest", "2 - Track", "3 - Urban", "4 - Fields"]

    def __init__(
        self,
        root: str,
        fold: str,
        transform=None,
        balanced: bool = False,
        patch_size: tuple = (128, 128),
        patch_stride: tuple = None,
    ):
        self.root = pathlib.Path(root)
        self.transform = transform

        self.patch_size = patch_size
        self.patch_stride = patch_stride
        if patch_stride is None:
            self.patch_stride = patch_size

        # Preload the data
        sar_filename = self.root / "bretigny_seg.npz"
        if not sar_filename.exists():
            raise RuntimeError(f"Cannot find the file {sar_filename}")
        sar_data = np.load(sar_filename)
        self.HH, self.HV, self.VV = sar_data["HH"], sar_data["HV"], sar_data["VV"]

        if balanced:
            label_filename = self.root / "bretigny_seg_4ROI_balanced.npz"
        else:
            label_filename = self.root / "bretigny_seg_4ROI.npz"
        if not label_filename.exists():
            raise RuntimeError(f"Cannot find the label file {label_filename}")
        self.labels = np.load(label_filename)["arr_0"]

        if fold not in ["train", "valid", "test"]:
            raise ValueError(
                f"Unrecognized fold {fold}. Should be either train, valid or test"
            )

        # Crop the data with respect to the fold
        if fold == "train":
            col_start = 0
            col_end = int(0.70 * self.HH.shape[1])
        elif fold == "valid":
            col_start = int(0.70 * self.HH.shape[1]) + 1
            col_end = int(0.85 * self.HH.shape[1])
        else:
            col_start = int(0.85 * self.HH.shape[1]) + 1
            col_end = self.HH.shape[1]

        self.HH = self.HH[:, col_start:col_end]
        self.HV = self.HV[:, col_start:col_end]
        self.VV = self.VV[:, col_start:col_end]
        self.labels = self.labels[:, col_start:col_end]

        # Precompute the dimension of the grid of patches
        nrows = self.HH.shape[0]
        ncols = self.HH.shape[1]

        nrows_patch, ncols_patch = self.patch_size
        row_stride, col_stride = self.patch_stride

        self.nsamples_per_rows = (nrows - nrows_patch) // row_stride + 1
        self.nsamples_per_cols = (ncols - ncols_patch) // col_stride + 1
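        # For example, with the default patch_size=(128, 128) and
        # patch_stride=(128, 128), a fold of shape (1000, 700) would yield
        # (1000 - 128) // 128 + 1 = 7 patch rows and (700 - 128) // 128 + 1 = 5
        # patch columns, i.e. 35 patches in total (the fold shape here is a
        # hypothetical illustration, not a property of the dataset).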

    def __len__(self) -> int:
        """
        Returns the total number of patches in the whole image.

        Returns:
            the total number of patches in the dataset
        """
        return self.nsamples_per_rows * self.nsamples_per_cols

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """
        Returns the patch at the given index.

        Arguments:
            idx (int): Index

        Returns:
            tuple: (patch, labels) where patch stacks the three complex valued
            polarizations HH, HV, VV and labels contains the aligned semantic labels
        """
        row_stride, col_stride = self.patch_stride
        start_row = (idx // self.nsamples_per_cols) * row_stride
        start_col = (idx % self.nsamples_per_cols) * col_stride
        num_rows, num_cols = self.patch_size
        patches = [
            patch[
                start_row : (start_row + num_rows), start_col : (start_col + num_cols)
            ]
            for patch in [self.HH, self.HV, self.VV]
        ]
        patches = np.stack(patches)
        if self.transform is not None:
            patches = self.transform(patches)

        labels = self.labels[
            start_row : (start_row + num_rows), start_col : (start_col + num_cols)
        ]

        return patches, labels
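

# Minimal usage sketch (not part of the library API): it assumes the Bretigny
# npz files live under a hypothetical `rootdir` and uses a simple amplitude
# transform so that patches can be batched by a standard PyTorch DataLoader.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    rootdir = "/path/to/bretigny"  # hypothetical location of the npz files
    dataset = Bretigny(
        rootdir, fold="train", patch_size=(128, 128), transform=lambda x: np.abs(x)
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    X, y = next(iter(loader))
    # X: (8, 3, 128, 128) real-valued amplitudes; y: (8, 128, 128) label map
    print(X.shape, y.shape)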