Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/atrnet_star/dataset.py: 0%

98 statements  


# MIT License

# Copyright (c) 2025 Rodolphe Durand

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import pathlib
from typing import Callable, Literal, Optional

import scipy.io
from torch.utils.data import Dataset

from .parse_xml import xml_to_dict


def gather_ATRNetSTAR_datafiles(
    rootdir: pathlib.Path,
    split: Literal["train", "test", "all"],
) -> list[str]:
    """
    Gathers all the ATRNet-STAR data files of the specified split found under the root directory.

    Returns a list of the sample paths (without the extension).
    """

    data_files = []

    if split != "all":
        split_dir = rootdir / split
    else:
        split_dir = rootdir
    logging.debug(f"Looking for all samples in {split_dir}")

    # only look for data files (avoids duplicates)
    for filename in split_dir.glob("**/*.mat"):
        if not filename.is_file():
            continue

        # strip the .mat extension from the file path (the .xml annotation shares the same stem)
        sample_name = str(filename.with_suffix(""))

        # add the sample name to the list of known samples
        data_files.append(sample_name)

    return data_files
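
# Expected on-disk layout, as a sketch inferred from the loader above (the directory
# and file names below are illustrative): one .mat / .xml pair per sample, e.g.
#
#   <benchmark_dir>/train/.../sample_0001.mat   (slant-range complex image, quad pol)
#   <benchmark_dir>/train/.../sample_0001.xml   (annotation header)
#
# gather_ATRNetSTAR_datafiles(pathlib.Path("<benchmark_dir>"), "train") would then
# return [".../sample_0001", ...], i.e. the shared stems without extension.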


class ATRNetSTARSample:
    """
    This class implements a sample from the ATRNet-STAR dataset.
    Only slant-range quad polarization complex images are supported.

    The extracted complex image is stored in the `data` attribute.

    The data is a dictionary whose keys are the polarizations ('HH', 'HV', 'VH' or 'VV')
    and whose values are numpy arrays of the corresponding complex image.

    The image annotations are stored in the `annotation` attribute. It contains all the fields
    of the XML annotation, as a dictionary.

    Arguments:
        sample_name (str): the name of the file to load WITHOUT the .mat or .xml extension
        transform (Callable): (optional) A transform to apply to the data
    """

    DATA_FILETYPE = ".mat"
    HEADER_FILETYPE = ".xml"

    POL_KEY_TRANSLATION_DICT = {
        "HH": "data_hh",
        "HV": "data_hv",
        "VH": "data_vh",
        "VV": "data_vv",
    }

    def __init__(self, sample_name: str, transform: Optional[Callable] = None):
        self._annotation = xml_to_dict(sample_name + self.HEADER_FILETYPE)
        self._data = {}
        self.transform = transform

        image = scipy.io.loadmat(sample_name + self.DATA_FILETYPE)
        for pol in self.POL_KEY_TRANSLATION_DICT.keys():
            self._data[pol] = image[self.POL_KEY_TRANSLATION_DICT[pol]]

    @property
    def data(self):
        if self.transform is not None:
            return self.transform(self._data)
        return self._data

    @property
    def annotation(self):
        return self._annotation
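
# A minimal usage sketch (the path below is hypothetical; it must point to an existing
# .mat / .xml pair on disk, given without the extension):
#
#   sample = ATRNetSTARSample("/data/ATRNet-STAR/SOC_40classes/train/scene_0001")
#   hh_image = sample.data["HH"]          # complex-valued numpy array
#   target = sample.annotation["object"]  # annotation fields parsed from the XML header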


class ATRNetSTAR(Dataset):
    """
    Implements a PyTorch Dataset for the ATRNet-STAR dataset presented in:

    Yongxiang Liu, Weijie Li, Li Liu, Jie Zhou, Bowen Peng, Yafei Song, Xuying Xiong, Wei Yang,
    Tianpeng Liu, Zhen Liu, & Xiang Li. (2025).
    ATRNet-STAR: A Large Dataset and Benchmark Towards Remote Sensing Object Recognition in the Wild.

    Only slant-range quad polarization complex images are supported.

    The dataset is composed of pre-defined benchmarks (see paper).
    Downloading them automatically is possible, but a Hugging Face authentication token
    is needed, as being logged in is required for this dataset.

    Warning: samples are ordered by type, so shuffling them is recommended.

    Arguments:
        root_dir (str): The root directory in which the different benchmarks are placed.
            Will be created if it does not exist.
        benchmark (str): (optional) Chosen benchmark. If not specified, SOC_40 (the entire dataset) is used by default.
        split (str): (optional) Chosen split ('train', 'test' or 'all' for both). Those are pre-defined by the dataset. Default: 'all'
        download (bool): (optional) Whether or not to download the dataset if it is not found. Default: False
        class_level (str): (optional) The level of precision chosen for the class attributed to a sample.
            Either 'category', 'class' or 'type'. Default: 'type' (the finest granularity)
        get_annotations (bool): (optional) If `False`, a dataset item is a tuple (`sample`, `target class`) (default).
            If `True`, the entire sample annotation dictionary is also returned:
            (`sample`, `target class`, `annotation dict`).
        transform (Callable): (optional) A transform to apply to the data
    """

    # Hugging Face repository constants
    HF_REPO_ID = "waterdisappear/ATRNet-STAR"
    HF_BENCHMARKS_DIR_PATH = pathlib.Path("Slant_Range/complex_float_quad/")

    # SOC_40classes is the complete original dataset, with the train/test splits defined
    BENCHMARKS = [
        "SOC_40classes",
        "EOC_azimuth",
        "EOC_band",
        "EOC_depression",
        "EOC_scene",
    ]

    _ALLOWED_BENCHMARKS = BENCHMARKS + ["SOC_40"]
    # sorted for prettier log messages later on
    _ALLOWED_BENCHMARKS.sort(reverse=True)

    # EOC_polarization consists of training with one polarization and testing with another
    # (to implement? or leave it to the user?)
    #
    # SOC_50 mixes the MSTAR dataset with a similar amount of samples from ATRNet-STAR.
    # This should probably be done manually by the user.
    #

    ### class names for all levels

    CATEGORIES = ["Car", "Speacial", "Truck", "Bus"]

    CLASSES = [
        "Large_Car",
        "Medium_SUV",
        "Compact_SUV",
        "Mini_Car",
        "Medium_Car",
        "ECV",
        "Ambulance",
        "Road_Roller",
        "Shovel_Loader",
        "Light_DT",
        "Pickup",
        "Mixer_Truck",
        "Heavy_DT",
        "Medium_TT",
        "Light_PV",
        "Heavy_FT",
        "Forklift",
        "Heavy_ST",
        "Small_Bus",
        "Medium_Bus",
        "Large_Bus",
    ]

    TYPES = [
        "Great_Wall_Voleex_C50",
        "Hongqi_h5",
        "Hongqi_CA7180A3E",
        "Chang'an_CS75_Plus",
        "Chevrolet_Blazer_1998",
        "Changfeng_Cheetah_CFA6473C",
        "Jeep_Patriot",
        "Mitsubishi_Outlander_2003",
        "Lincoln_MKC",
        "Hawtai_EV160B",
        "Chery_qq3",
        "Buick_Excelle_GT",
        "Chery_Arrizo 5",
        "Lveco_Proud_2009",
        "JINBEI_SY5033XJH",
        "Changlin_8228-5",
        "SDLG_ZL40F",
        "Foton_BJ1045V9JB5-54",
        "FAW_Jiabao_T51",
        "WAW_Aochi_1800",
        "Huanghai_N1",
        "Great_Wall_poer",
        "CNHTC_HOWO",
        "Dongfeng_Tianjin_DFH2200B",
        "WAW_Aochi_Hongrui",
        "Dongfeng_Duolika",
        "JAC_Junling",
        "FAW_J6P",
        "SHACMAN_DeLong_M3000",
        "Hyundai_HLF25_II",
        "Dongfeng_Tianjin_KR230",
        "SHACMAN_DeLong_X3000",
        "Wuling_Rongguang_V",
        "Buick_GL8",
        "Chang'an_Starlight_4500",
        "Dongfeng_Forthing_Lingzhi",
        "Yangzi_YZK6590XCA",
        "Dongfeng_EQ6608LTV",
        "MAXUS_V80",
        "Yutong_ZK6120HY1",
    ]

    def __init__(
        self,
        root_dir: str,
        benchmark: Optional[str] = None,
        split: Literal["train", "test", "all"] = "all",
        download: bool = False,
        class_level: Literal["type", "class", "category"] = "type",
        get_annotations: bool = False,
        transform: Optional[Callable] = None,
    ):
        super().__init__()
        self.root_dir = pathlib.Path(root_dir)
        self.split = split
        self.class_level = class_level
        self.download = download
        self.get_annotations = get_annotations
        self.transform = transform

        if benchmark is None:
            # if no benchmark is given, default behavior is to use the entire dataset (SOC_40)
            logging.info(
                "No benchmark was specified. SOC_40 (full dataset) will be used."
            )
            benchmark = "SOC_40classes"

        self.benchmark = benchmark
        self._verify_inputs()

        if self.benchmark == "SOC_40":
            # allow use of the name given to the benchmark in the paper instead of the file
            # name actually used in their repository
            # (more consistent with the rest of their file naming)
            self.benchmark = "SOC_40classes"

        self.benchmark_path = self.root_dir / self.benchmark

        if not self.benchmark_path.exists():
            if not download:
                raise RuntimeError(
                    f"{self.benchmark} benchmark not found. You can use download=True to download it"
                )
            else:
                self._download_dataset()

        # gather samples
        self.datafiles = gather_ATRNetSTAR_datafiles(
            rootdir=self.benchmark_path, split=self.split
        )
        logging.debug(f"Found {len(self.datafiles)} samples.")

    def _verify_inputs(self) -> None:
        """Verify inputs are valid"""
        if self.class_level not in ["type", "class", "category"]:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )

        if self.benchmark not in self._ALLOWED_BENCHMARKS:
            benchmarks_with_quotes = [f"'{b}'" for b in self._ALLOWED_BENCHMARKS]
            raise ValueError(
                f"Unknown benchmark. Should be one of {', '.join(benchmarks_with_quotes)} or None"
            )

        if self.split not in ["train", "test", "all"]:
            raise ValueError(
                f"Unexpected split value. Got {self.split} instead of 'train', 'test' or 'all'."
            )

    def _download_dataset(self) -> None:
        """
        Downloads the specified benchmark.
        It is placed in a directory named after the benchmark, inside root_dir.
        """
        from .download import check_7z, download_benchmark

        check_7z()
        download_benchmark(
            benchmark=self.benchmark,
            root_dir=self.root_dir,
            hf_repo_id=self.HF_REPO_ID,
            hf_benchmark_path=self.HF_BENCHMARKS_DIR_PATH,
        )

    @property
    def classes(self) -> list[str]:
        """
        Get the names of all the classes at the chosen class_level.
        """
        if self.class_level == "category":
            return self.CATEGORIES

        elif self.class_level == "class":
            return self.CLASSES

        elif self.class_level == "type":
            return self.TYPES

        else:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )
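
    # Label-encoding note (illustrative): with class_level="category" the classes
    # property returns ["Car", "Speacial", "Truck", "Bus"], so a sample annotated as
    # "Truck" is encoded by __getitem__ below as class_idx = 2, and
    # dataset.classes[class_idx] recovers the name.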

    def __len__(self) -> int:
        return len(self.datafiles)

    def __getitem__(self, index: int) -> tuple:
        """
        Returns the sample at the given index. Applies the transform if provided.
        If `get_annotations` is True, the entire annotation dict is also returned.
        The return type depends on the transform applied and on whether `get_annotations` is True.

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the classes list
            (Optional) annotation : the annotation dict of the sample. Only if `get_annotations` is True.
        """
        sample_name = self.datafiles[index]
        sample = ATRNetSTARSample(sample_name=sample_name, transform=self.transform)
        class_idx = self.classes.index(sample.annotation["object"][self.class_level])

        if self.get_annotations:
            return (
                sample.data,
                class_idx,
                sample.annotation,
            )

        return sample.data, class_idx
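
# A minimal end-to-end usage sketch (the root directory and benchmark below are
# illustrative; download=True additionally requires a Hugging Face login):
#
#   dataset = ATRNetSTAR(
#       root_dir="/data/ATRNet-STAR",
#       benchmark="SOC_40",
#       split="train",
#       class_level="category",
#       download=False,
#   )
#   data, class_idx = dataset[0]        # without a transform, data is a dict of complex numpy arrays per polarization
#   label = dataset.classes[class_idx]  # e.g. "Car"
#
#   # With a transform that converts the dict to a tensor, the dataset can be wrapped in a
#   # torch.utils.data.DataLoader (shuffle=True is recommended, see the class docstring).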