Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/atrnet_star/dataset.py: 0%
98 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 05:19 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-26 05:19 +0000
1# MIT License
3# Copyright (c) 2025 Rodolphe Durand
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
import logging
import pathlib
from typing import Callable, Literal, Optional

import scipy
import scipy.io  # explicit: `scipy.io.loadmat` is used below; bare `import scipy` only works via lazy submodule loading (scipy >= 1.8)
from torch.utils.data import Dataset

from .parse_xml import xml_to_dict
def gather_ATRNetSTAR_datafiles(
    rootdir: pathlib.Path,
    split: Literal["train", "test", "all"],
) -> list[str]:
    """
    Gather all the ATRNet-STAR datafiles from the specified split in the root directory.

    Arguments:
        rootdir: the benchmark root directory (containing the `train` and `test` subdirectories)
        split: which split to collect ('train', 'test', or 'all' for both)

    Returns:
        The sorted list of the paths of the samples (without the extension).
        Sorting makes the sample ordering deterministic: `Path.glob` order is
        filesystem-dependent, which would otherwise make dataset indices
        irreproducible across machines.
    """
    # 'train'/'test' samples live in their own subdirectory; 'all' scans the whole tree
    if split != "all":
        split_dir = rootdir / split
    else:
        split_dir = rootdir
    logging.debug(f"Looking for all samples in {split_dir}")

    data_files = []
    # only look for .mat data files to avoid duplicates
    # (every sample also has a matching .xml annotation file next to it)
    for filename in split_dir.glob("**/*.mat"):
        if not filename.is_file():
            continue

        # strip the .mat extension: the sample name is shared by the
        # data (.mat) and header (.xml) files
        sample_name = str(filename.with_suffix(""))

        # add sample name to the list of known samples
        data_files.append(sample_name)

    # sorted for reproducible ordering (glob yields files in arbitrary OS order)
    return sorted(data_files)
class ATRNetSTARSample:
    """
    A single sample of the ATRNet-STAR dataset.
    Only slant-range quad polarization complex images are supported.

    The extracted complex image is exposed through the `data` property as a
    dictionary mapping each polarization ('HH', 'HV', 'VH' or 'VV') to a numpy
    array of the corresponding complex image. If a transform was provided, it
    is applied to that dictionary on every access.

    The image annotations are exposed through the `annotation` property: a
    dictionary containing every field of the XML annotation.

    Arguments:
        sample_name (str): the name of the file to load WITHOUT the .mat or .xml extension
        transform (Callable): (optional) A transform to apply to the data
    """

    DATA_FILETYPE = ".mat"
    HEADER_FILETYPE = ".xml"

    POL_KEY_TRANSLATION_DICT = {
        "HH": "data_hh",
        "HV": "data_hv",
        "VH": "data_vh",
        "VV": "data_vv",
    }

    def __init__(self, sample_name: str, transform: Optional[Callable] = None):
        # parse the XML header and load the .mat image sharing the same stem
        self._annotation = xml_to_dict(sample_name + self.HEADER_FILETYPE)
        self.transform = transform

        mat_contents = scipy.io.loadmat(sample_name + self.DATA_FILETYPE)
        self._data = {
            polarization: mat_contents[mat_key]
            for polarization, mat_key in self.POL_KEY_TRANSLATION_DICT.items()
        }

    @property
    def data(self):
        # the transform (if any) is applied lazily, on each access
        return self._data if self.transform is None else self.transform(self._data)

    @property
    def annotation(self):
        return self._annotation
class ATRNetSTAR(Dataset):
    """
    Implements a PyTorch Dataset for the ATRNet-STAR dataset presented in :

    Yongxiang Liu, Weijie Li, Li Liu, Jie Zhou, Bowen Peng, Yafei Song, Xuying Xiong, Wei Yang,
    Tianpeng Liu, Zhen Liu, & Xiang Li. (2025).
    ATRNet-STAR: A Large Dataset and Benchmark Towards Remote Sensing Object Recognition in the Wild.

    Only slant-range quad polarization complex images are supported.

    The dataset is composed of pre-defined benchmarks (see paper).
    Downloading them automatically is possible, but a Hugging Face authentication token
    is needed as being logged in is required for this dataset.

    Warning : samples are ordered by type, shuffling them is recommended.

    Arguments:
        root_dir (str): The root directory in which the different benchmarks are placed.
            Will be created if it does not exist.
        benchmark (str): (optional) Chosen benchmark. If not specified, SOC_40 (entire dataset) will be used as default.
        split (str): (optional) Chosen split ('train', 'test' or 'all' for both). Those are pre-defined by the dataset. Default: 'all'
        download (bool): (optional) Whether or not to download the dataset if it is not found. Default: False
        class_level (str): (optional) The level of precision chosen for the class attributed to a sample.
            Either 'category', 'class' or 'type'. Default: 'type' (the finest granularity)
        get_annotations (bool): (optional) If `False`, a dataset item will be a tuple (`sample`, `target class`) (default).
            If `True`, the entire sample annotation
            dictionary will also be returned: (`sample`, `target class`, `annotation dict`).
        transform (Callable): (optional) A transform to apply to the data
    """

    # Hugging Face repository constants
    HF_REPO_ID = "waterdisappear/ATRNet-STAR"
    HF_BENCHMARKS_DIR_PATH = pathlib.Path("Slant_Range/complex_float_quad/")

    # SOC_40classes is the complete original dataset, with the train/test splits defined
    BENCHMARKS = [
        "SOC_40classes",
        "EOC_azimuth",
        "EOC_band",
        "EOC_depression",
        "EOC_scene",
    ]

    # "SOC_40" (the paper's name) is accepted as an alias of "SOC_40classes"
    _ALLOWED_BENCHMARKS = BENCHMARKS + ["SOC_40"]
    # prettier logs later (sorted once, at class-creation time)
    _ALLOWED_BENCHMARKS.sort(reverse=True)

    # EOC_polarization consists of training using one polarization and testing using another
    # (to implement ? or leave it to the user ?)
    #
    # SOC_50 mixes the MSTAR dataset with a similar amount of samples from ATRNet-STAR.
    # This should probably be done manually by the user
    #

    ### class names for all levels
    # NOTE(review): list order defines the class indices returned by __getitem__
    # (via `self.classes.index(...)`) — do not reorder.
    # "Speacial" is presumably the spelling used in the dataset's own annotations;
    # verify against the XML files before "fixing" it.

    CATEGORIES = ["Car", "Speacial", "Truck", "Bus"]

    CLASSES = [
        "Large_Car",
        "Medium_SUV",
        "Compact_SUV",
        "Mini_Car",
        "Medium_Car",
        "ECV",
        "Ambulance",
        "Road_Roller",
        "Shovel_Loader",
        "Light_DT",
        "Pickup",
        "Mixer_Truck",
        "Heavy_DT",
        "Medium_TT",
        "Light_PV",
        "Heavy_FT",
        "Forklift",
        "Heavy_ST",
        "Small_Bus",
        "Medium_Bus",
        "Large_Bus",
    ]

    TYPES = [
        "Great_Wall_Voleex_C50",
        "Hongqi_h5",
        "Hongqi_CA7180A3E",
        "Chang'an_CS75_Plus",
        "Chevrolet_Blazer_1998",
        "Changfeng_Cheetah_CFA6473C",
        "Jeep_Patriot",
        "Mitsubishi_Outlander_2003",
        "Lincoln_MKC",
        "Hawtai_EV160B",
        "Chery_qq3",
        "Buick_Excelle_GT",
        "Chery_Arrizo 5",
        "Lveco_Proud_2009",
        "JINBEI_SY5033XJH",
        "Changlin_8228-5",
        "SDLG_ZL40F",
        "Foton_BJ1045V9JB5-54",
        "FAW_Jiabao_T51",
        "WAW_Aochi_1800",
        "Huanghai_N1",
        "Great_Wall_poer",
        "CNHTC_HOWO",
        "Dongfeng_Tianjin_DFH2200B",
        "WAW_Aochi_Hongrui",
        "Dongfeng_Duolika",
        "JAC_Junling",
        "FAW_J6P",
        "SHACMAN_DeLong_M3000",
        "Hyundai_HLF25_II",
        "Dongfeng_Tianjin_KR230",
        "SHACMAN_DeLong_X3000",
        "Wuling_Rongguang_V",
        "Buick_GL8",
        "Chang'an_Starlight_4500",
        "Dongfeng_Forthing_Lingzhi",
        "Yangzi_YZK6590XCA",
        "Dongfeng_EQ6608LTV",
        "MAXUS_V80",
        "Yutong_ZK6120HY1",
    ]

    def __init__(
        self,
        root_dir: str,
        benchmark: Optional[str] = None,
        split: Literal["train", "test", "all"] = "all",
        download: bool = False,
        class_level: Literal["type", "class", "category"] = "type",
        get_annotations: bool = False,
        transform: Optional[Callable] = None,
    ):
        super().__init__()
        self.root_dir = pathlib.Path(root_dir)
        self.split = split
        self.class_level = class_level
        self.download = download
        self.get_annotations = get_annotations
        self.transform = transform

        if benchmark is None:
            # if no benchmark is given, default behavior should be to use the entire dataset (SOC_40)
            logging.info(
                "No benchmark was specified. SOC_40 (full dataset) will be used."
            )
            benchmark = "SOC_40classes"

        self.benchmark = benchmark
        # validate benchmark/split/class_level before touching the filesystem
        self._verify_inputs()

        if self.benchmark == "SOC_40":
            # allow use of the name given to the benchmark in the paper instead of the file
            # name actually used in their repository
            # (more consistent with the rest of their file naming)
            self.benchmark = "SOC_40classes"

        self.benchmark_path = self.root_dir / self.benchmark

        if not self.benchmark_path.exists():
            if not download:
                raise RuntimeError(
                    f"{self.benchmark} benchmark not found. You can use download=True to download it"
                )
            else:
                self._download_dataset()

        # gather samples (paths without extension; one entry per .mat/.xml pair)
        self.datafiles = gather_ATRNetSTAR_datafiles(
            rootdir=self.benchmark_path, split=self.split
        )
        logging.debug(f"Found {len(self.datafiles)} samples.")

    def _verify_inputs(self) -> None:
        """Verify inputs are valid"""
        if self.class_level not in ["type", "class", "category"]:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )

        if self.benchmark not in self._ALLOWED_BENCHMARKS:
            benchmarks_with_quotes = [f"'{b}'" for b in self._ALLOWED_BENCHMARKS]
            raise ValueError(
                f"Unknown benchmark. Should be one of {', '.join(benchmarks_with_quotes)} or None"
            )

        if self.split not in ["train", "test", "all"]:
            raise ValueError(
                f"Unexpected split value. Got {self.split} instead of 'train', 'test' or 'all'."
            )

    def _download_dataset(self) -> None:
        """
        Downloads the specified benchmark.
        Will be placed in a directory named like the benchmark, in root_dir
        """
        # imported lazily so the optional download dependencies (7z, huggingface)
        # are only required when download=True is actually used
        from .download import check_7z, download_benchmark

        check_7z()
        download_benchmark(
            benchmark=self.benchmark,
            root_dir=self.root_dir,
            hf_repo_id=self.HF_REPO_ID,
            hf_benchmark_path=self.HF_BENCHMARKS_DIR_PATH,
        )

    @property
    def classes(self) -> list[str]:
        """
        Get the names of all classes at class_level,
        """
        if self.class_level == "category":
            return self.CATEGORIES

        elif self.class_level == "class":
            return self.CLASSES

        elif self.class_level == "type":
            return self.TYPES

        else:
            # defensive guard: _verify_inputs already rejects other values in __init__
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of type, class or category."
            )

    def __len__(self) -> int:
        return len(self.datafiles)

    def __getitem__(self, index: int) -> tuple:
        """
        Returns the sample at the given index. Applies the transform
        if provided. If `get_annotations` is True, also return the entire annotation dict.
        The return type depends on the transform applied and whether or not get_annotations is True

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the classes list
            (Optional) annotation : the annotation dict of the sample. Only if `get_annotations` is True.
        """
        # samples are loaded lazily from disk on every access
        sample_name = self.datafiles[index]
        sample = ATRNetSTARSample(sample_name=sample_name, transform=self.transform)
        # the label string in the XML annotation must match an entry of the
        # class list exactly (raises ValueError otherwise)
        class_idx = self.classes.index(sample.annotation["object"][self.class_level])

        if self.get_annotations:
            return (
                sample.data,
                class_idx,
                sample.annotation,
            )

        return sample.data, class_idx