# MIT License

# Copyright (c) 2025 Rodolphe Durand

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import os
import pathlib
from typing import Callable, Literal, Optional, Union

import scipy.io
from torch.utils.data import Dataset

from .parse_xml import xml_to_dict


def gather_ATRNetSTAR_datafiles(
    rootdir: pathlib.Path,
    split: str,
) -> list[str]:
    """
    This function gathers all the ATRNet-STAR data files of the specified split in the root directory.

    It returns a list of the paths of the samples (without the extension).
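
    For example, a (hypothetical) file pair ``<split_dir>/foo/bar.mat`` / ``<split_dir>/foo/bar.xml``
    results in the single entry ``"<split_dir>/foo/bar"`` in the returned list.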
42 """
44 data_files = []
46 if split != "all":
47 split_dir = rootdir / split
48 else:
49 split_dir = rootdir
51 if not split_dir.exists():
52 raise ValueError(
53 f"{str(split_dir)} not found. Possible splits of benchmark {rootdir.name} are : {[f.name for f in os.scandir(rootdir) if f.is_dir()]}"
54 )
56 logging.debug(f"Looking for all samples in {split_dir}")
58 # only look for data files (avoid duplicates)
59 for filename in split_dir.glob("**/*.mat"):
60 if not filename.is_file():
61 continue
63 # strip file of the .xml or .mat extension
64 sample_name = str(filename.with_suffix(""))
66 # add sample name to the list of known samples
67 data_files.append(sample_name)
69 return data_files


class ATRNetSTARSample:
    """
    This class implements a sample from the ATRNet-STAR dataset.
    Only slant-range quad-polarization complex images are supported.

    The extracted complex image is stored in the `data` attribute.
    The data is a dictionary whose keys are the polarizations ('HH', 'HV', 'VH' or 'VV')
    and whose values are numpy arrays of the corresponding complex image.

    The image annotations are stored in the `annotation` attribute. It contains all the fields
    of the XML annotation, as a dictionary.

    Arguments:
        sample_name (str): the name of the file to load WITHOUT the .mat or .xml extension
        transform (Callable): (optional) A transform to apply to the data
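
    Example (illustrative sketch; ``path/to/sample_0001`` is a hypothetical sample path
    with matching ``.mat`` and ``.xml`` files)::

        sample = ATRNetSTARSample("path/to/sample_0001")
        hh_image = sample.data["HH"]          # complex numpy array for the HH polarization
        fields = sample.annotation["object"]  # object-related fields from the XML annotation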
88 """
90 DATA_FILETYPE = ".mat"
91 HEADER_FILETYPE = ".xml"
93 POL_KEY_TRANSLATION_DICT = {
94 "HH": "data_hh",
95 "HV": "data_hv",
96 "VH": "data_vh",
97 "VV": "data_vv",
98 }
100 def __init__(self, sample_name: str, transform: Optional[Callable] = None):
101 self._annotation = xml_to_dict(sample_name + self.HEADER_FILETYPE)
102 self._data = {}
103 self.transform = transform
105 image = scipy.io.loadmat(sample_name + self.DATA_FILETYPE)
106 for pol in self.POL_KEY_TRANSLATION_DICT.keys():
107 self._data[pol] = image[self.POL_KEY_TRANSLATION_DICT[pol]]
109 @property
110 def data(self):
111 if self.transform is not None:
112 return self.transform(self._data)
113 return self._data
115 @property
116 def annotation(self):
117 return self._annotation


class ATRNetSTAR(Dataset):
    """
    Implements a PyTorch Dataset for the ATRNet-STAR dataset presented in:

    Yongxiang Liu, Weijie Li, Li Liu, Jie Zhou, Bowen Peng, Yafei Song, Xuying Xiong, Wei Yang,
    Tianpeng Liu, Zhen Liu, & Xiang Li. (2025).
    ATRNet-STAR: A Large Dataset and Benchmark Towards Remote Sensing Object Recognition in the Wild.

    Only slant-range quad-polarization complex images are supported.

    The dataset is composed of pre-defined benchmarks (see the paper).
    They can be downloaded automatically, but a Hugging Face authentication token
    is needed since being logged in is required for this dataset.

    Warning: samples are ordered by type, so shuffling them is recommended.

    Arguments:
        root_dir (str): The root directory in which the different benchmarks are placed.
            Will be created if it does not exist.
        benchmark (str): (optional) Chosen benchmark. If not specified, SOC_40 (the entire dataset) is used by default.
        split (str): Chosen split ('train', 'test', ... or 'all' for all benchmark samples).
            These splits are pre-defined by the dataset for each benchmark.
        download (bool): (optional) Whether or not to download the dataset if it is not found. Default: False
        class_level (str): (optional) The level of precision chosen for the class attributed to a sample.
            Either 'category', 'class' or 'type'. Default: 'type' (the finest granularity)
        get_annotations (bool): (optional) If `False`, a dataset item is a tuple (`sample`, `target class`) (default).
            If `True`, the entire sample annotation dictionary is also returned:
            (`sample`, `target class`, `annotation dict`).
        transform (Callable): (optional) A transform to apply to the data
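
    Example (illustrative sketch; assumes the SOC_40classes benchmark is already
    available under the hypothetical directory ``datasets/ATRNet-STAR``)::

        dataset = ATRNetSTAR(
            root_dir="datasets/ATRNet-STAR",
            split="train",
            benchmark="SOC_40",
            class_level="type",
        )
        data, class_idx = dataset[0]
        # data is a dict mapping 'HH', 'HV', 'VH' and 'VV' to complex numpy arrays
        print(dataset.classes[class_idx])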
149 """
151 # Hugging Face repository constants
152 HF_REPO_ID = "waterdisappear/ATRNet-STAR"
153 HF_BENCHMARKS_DIR_PATH = pathlib.Path("Slant_Range/complex_float_quad/")
155 # SOC_40classes is the complete original dataset, with the train/test splits defined
156 BENCHMARKS = [
157 "SOC_40classes",
158 "EOC_azimuth",
159 "EOC_band",
160 "EOC_depression",
161 "EOC_scene",
162 ]
164 _ALLOWED_BENCHMARKS = BENCHMARKS + ["SOC_40"]
165 # prettier logs later
166 _ALLOWED_BENCHMARKS.sort(reverse=True)
168 # EOC_polarization consists of training using one polarization and testing using another
169 # (to implement ? or leave it to the user ?)
170 #
171 # SOC_50 mixes the MSTAR dataset with a similar amount of samples from ATRNet-STAR.
172 # This should probably be done manually by the user
173 #
175 ### class names for all levels
177 CATEGORIES = ["Car", "Speacial", "Truck", "Bus"]

    CLASSES = [
        "Large_Car",
        "Medium_SUV",
        "Compact_SUV",
        "Mini_Car",
        "Medium_Car",
        "ECV",
        "Ambulance",
        "Road_Roller",
        "Shovel_Loader",
        "Light_DT",
        "Pickup",
        "Mixer_Truck",
        "Heavy_DT",
        "Medium_TT",
        "Light_PV",
        "Heavy_FT",
        "Forklift",
        "Heavy_ST",
        "Small_Bus",
        "Medium_Bus",
        "Large_Bus",
    ]

    TYPES = [
        "Great_Wall_Voleex_C50",
        "Hongqi_h5",
        "Hongqi_CA7180A3E",
        "Chang'an_CS75_Plus",
        "Chevrolet_Blazer_1998",
        "Changfeng_Cheetah_CFA6473C",
        "Jeep_Patriot",
        "Mitsubishi_Outlander_2003",
        "Lincoln_MKC",
        "Hawtai_EV160B",
        "Chery_qq3",
        "Buick_Excelle_GT",
        "Chery_Arrizo 5",
        "Lveco_Proud_2009",
        "JINBEI_SY5033XJH",
        "Changlin_8228-5",
        "SDLG_ZL40F",
        "Foton_BJ1045V9JB5-54",
        "FAW_Jiabao_T51",
        "WAW_Aochi_1800",
        "Huanghai_N1",
        "Great_Wall_poer",
        "CNHTC_HOWO",
        "Dongfeng_Tianjin_DFH2200B",
        "WAW_Aochi_Hongrui",
        "Dongfeng_Duolika",
        "JAC_Junling",
        "FAW_J6P",
        "SHACMAN_DeLong_M3000",
        "Hyundai_HLF25_II",
        "Dongfeng_Tianjin_KR230",
        "SHACMAN_DeLong_X3000",
        "Wuling_Rongguang_V",
        "Buick_GL8",
        "Chang'an_Starlight_4500",
        "Dongfeng_Forthing_Lingzhi",
        "Yangzi_YZK6590XCA",
        "Dongfeng_EQ6608LTV",
        "MAXUS_V80",
        "Yutong_ZK6120HY1",
    ]

    def __init__(
        self,
        root_dir: str,
        split: Union[Literal["train", "test", "all"], str],
        benchmark: Optional[str] = None,
        download: bool = False,
        class_level: Literal["type", "class", "category"] = "type",
        get_annotations: bool = False,
        transform: Optional[Callable] = None,
    ):
        super().__init__()
        self.root_dir = pathlib.Path(root_dir)
        self.split = split
        self.class_level = class_level
        self.download = download
        self.get_annotations = get_annotations
        self.transform = transform

        if benchmark is None:
            # if no benchmark is given, default behavior should be to use the entire dataset (SOC_40)
            logging.info(
                "No benchmark was specified. SOC_40 (full dataset) will be used."
            )
            benchmark = "SOC_40classes"

        self.benchmark = benchmark
        self._verify_inputs()

        if self.benchmark == "SOC_40":
            # allow use of the name given to the benchmark in the paper instead of the file
            # name actually used in their repository
            # (more consistent with the rest of their file naming)
            self.benchmark = "SOC_40classes"

        self.benchmark_path = self.root_dir / self.benchmark

        if not self.benchmark_path.exists():
            if not download:
                raise RuntimeError(
                    f"{self.benchmark} benchmark not found. You can use download=True to download it"
                )
            else:
                self._download_dataset()

        # gather samples
        self.datafiles = gather_ATRNetSTAR_datafiles(
            rootdir=self.benchmark_path, split=self.split
        )
        logging.debug(f"Found {len(self.datafiles)} samples.")

    def _verify_inputs(self) -> None:
        """Verify inputs are valid"""
        if self.class_level not in ["type", "class", "category"]:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )

        if self.benchmark not in self._ALLOWED_BENCHMARKS:
            benchmarks_with_quotes = [f"'{b}'" for b in self._ALLOWED_BENCHMARKS]
            raise ValueError(
                f"Unknown benchmark. Should be one of {', '.join(benchmarks_with_quotes)} or None"
            )

    def _download_dataset(self) -> None:
        """
        Downloads the specified benchmark.
        It will be placed in a directory named after the benchmark, inside root_dir.
        """
        from .download import check_7z, download_benchmark

        check_7z()
        download_benchmark(
            benchmark=self.benchmark,
            root_dir=self.root_dir,
            hf_repo_id=self.HF_REPO_ID,
            hf_benchmark_path=self.HF_BENCHMARKS_DIR_PATH,
        )

    @property
    def classes(self) -> list[str]:
        """
        Get the names of all classes at the chosen class_level.
        """
        if self.class_level == "category":
            return self.CATEGORIES

        elif self.class_level == "class":
            return self.CLASSES

        elif self.class_level == "type":
            return self.TYPES

        else:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )

    def __len__(self) -> int:
        return len(self.datafiles)

    def __getitem__(self, index: int) -> tuple:
        """
        Returns the sample at the given index. Applies the transform
        if provided. If `get_annotations` is True, also returns the entire annotation dict.
        The return type depends on the transform applied and on whether or not get_annotations is True.

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the classes list
            (Optional) annotation : the annotation dict of the sample. Only if `get_annotations` is True.
        """
        sample_name = self.datafiles[index]
        sample = ATRNetSTARSample(sample_name=sample_name, transform=self.transform)
        class_idx = self.classes.index(sample.annotation["object"][self.class_level])

        if self.get_annotations:
            return (
                sample.data,
                class_idx,
                sample.annotation,
            )

        return sample.data, class_idx
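

# Minimal usage sketch (illustrative only, not part of the library API): builds a
# transform that stacks the four polarizations into a single complex tensor and wraps
# the dataset in a shuffled DataLoader. The root directory below is a hypothetical
# local path and the chosen benchmark is assumed to be already downloaded.
if __name__ == "__main__":
    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    def stack_polarizations(data: dict) -> torch.Tensor:
        # data maps 'HH', 'HV', 'VH' and 'VV' to complex numpy arrays of identical shape
        return torch.from_numpy(
            np.stack([data[pol] for pol in ("HH", "HV", "VH", "VV")])
        )

    dataset = ATRNetSTAR(
        root_dir="datasets/ATRNet-STAR",  # hypothetical path
        split="train",
        transform=stack_polarizations,
    )
    # samples are ordered by type, hence shuffle=True
    loader = DataLoader(dataset, batch_size=8, shuffle=True)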