Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/atrnet_star/dataset.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.10.4, created at 2025-08-21 11:56 +0000

1# MIT License 

2 

3# Copyright (c) 2025 Rodolphe Durand 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23import logging 

24import os 

25import pathlib 

26from typing import Callable, Literal, Optional, Union 

27 

28import scipy 

29from torch.utils.data import Dataset 

30 

31from .parse_xml import xml_to_dict 

32 

33 

def gather_ATRNetSTAR_datafiles(
    rootdir: pathlib.Path,
    split: str,
) -> list[str]:
    """
    Collect every ATRNet-STAR sample of the requested split under the root directory.

    Returns the sample paths as strings, stripped of their file extension
    (each sample is a `.mat` data file paired with a `.xml` annotation file).
    """
    # 'all' means the whole benchmark tree; any other value names a sub-directory
    split_dir = rootdir if split == "all" else rootdir / split

    if not split_dir.exists():
        raise ValueError(
            f"{str(split_dir)} not found. Possible splits of benchmark {rootdir.name} are : {[f.name for f in os.scandir(rootdir) if f.is_dir()]}"
        )

    logging.debug(f"Looking for all samples in {split_dir}")

    # Glob only the .mat data files so each sample is counted once
    # (its .xml annotation twin would otherwise duplicate it).
    return [
        str(candidate.with_suffix(""))
        for candidate in split_dir.glob("**/*.mat")
        if candidate.is_file()
    ]

70 

71 

class ATRNetSTARSample:
    """
    A single sample of the ATRNet-STAR dataset.
    Only slant-range quad polarization complex images are supported.

    The extracted complex image is exposed through the `data` attribute:
    a dictionary whose keys are the polarization ('HH', 'HV', 'VH' or 'VV')
    and whose values are numpy arrays of the corresponding complex image.

    The image annotations are exposed through the `annotation` attribute.
    It contains every field of the XML annotation, as a dictionary.

    Arguments:
        sample_name (str): the name of the file to load WITHOUT the .mat or .xml extension
        transform (Callable): (optional) A transform to apply to the data
    """

    DATA_FILETYPE = ".mat"
    HEADER_FILETYPE = ".xml"

    # maps the polarization name to the matching key inside the .mat file
    POL_KEY_TRANSLATION_DICT = {
        "HH": "data_hh",
        "HV": "data_hv",
        "VH": "data_vh",
        "VV": "data_vv",
    }

    def __init__(self, sample_name: str, transform: Optional[Callable] = None):
        self._annotation = xml_to_dict(sample_name + self.HEADER_FILETYPE)
        self.transform = transform

        raw = scipy.io.loadmat(sample_name + self.DATA_FILETYPE)
        self._data = {
            pol: raw[mat_key]
            for pol, mat_key in self.POL_KEY_TRANSLATION_DICT.items()
        }

    @property
    def data(self):
        # the transform (if any) is applied lazily, on every access
        if self.transform is None:
            return self._data
        return self.transform(self._data)

    @property
    def annotation(self):
        return self._annotation

118 

119 

class ATRNetSTAR(Dataset):
    """
    Implements a PyTorch Dataset for the ATRNet-STAR dataset presented in :

    Yongxiang Liu, Weijie Li, Li Liu, Jie Zhou, Bowen Peng, Yafei Song, Xuying Xiong, Wei Yang,
    Tianpeng Liu, Zhen Liu, & Xiang Li. (2025).
    ATRNet-STAR: A Large Dataset and Benchmark Towards Remote Sensing Object Recognition in the Wild.

    Only slant-range quad polarization complex images are supported.

    The dataset is composed of pre-defined benchmarks (see paper).
    Downloading them automatically is possible, but a Hugging Face authentication token
    is needed as being logged in is required for this dataset.

    Warning : samples are ordered by type, shuffling them is recommended.

    Arguments:
        root_dir (str): The root directory in which the different benchmarks are placed.
            Will be created if it does not exist.
        benchmark (str): (optional) Chosen benchmark. If not specified, SOC_40 (entire dataset) will be used as default.
        split (str): Chosen split ('train', 'test', ... or 'all' for all benchmark samples).
            Those are pre-defined by the dataset for each benchmark.
        download (bool): (optional) Whether or not to download the dataset if it is not found. Default: False
        class_level (str): (optional) The level of precision chosen for the class attributed to a sample.
            Either 'category', 'class' or 'type'. Default: 'type' (the finest granularity)
        get_annotations (bool): (optional) If `False`, a dataset item will be a tuple (`sample`, `target class`) (default).
            If `True`, the entire sample annotation
            dictionary will also be returned: (`sample`, `target class`, `annotation dict`).
        transform (Callable): (optional) A transform to apply to the data
    """

    # Hugging Face repository constants
    HF_REPO_ID = "waterdisappear/ATRNet-STAR"
    HF_BENCHMARKS_DIR_PATH = pathlib.Path("Slant_Range/complex_float_quad/")

    # SOC_40classes is the complete original dataset, with the train/test splits defined
    BENCHMARKS = [
        "SOC_40classes",
        "EOC_azimuth",
        "EOC_band",
        "EOC_depression",
        "EOC_scene",
    ]

    # "SOC_40" (the paper's name for the full dataset) is accepted as an alias
    # of "SOC_40classes" (the name used in the HF repository file tree)
    _ALLOWED_BENCHMARKS = BENCHMARKS + ["SOC_40"]
    # sorted once at class creation so error messages list benchmarks in a
    # stable, readable order (prettier logs later)
    _ALLOWED_BENCHMARKS.sort(reverse=True)

    # EOC_polarization consists of training using one polarization and testing using another
    # (to implement ? or leave it to the user ?)
    #
    # SOC_50 mixes the MSTAR dataset with a similar amount of samples from ATRNet-STAR.
    # This should probably be done manually by the user
    #

    ### class names for all levels
    # NOTE(review): "Speacial" below looks like a typo but presumably mirrors
    # the exact label spelling used in the dataset's XML annotations — confirm
    # against the annotation files before "fixing" it.

    CATEGORIES = ["Car", "Speacial", "Truck", "Bus"]

    CLASSES = [
        "Large_Car",
        "Medium_SUV",
        "Compact_SUV",
        "Mini_Car",
        "Medium_Car",
        "ECV",
        "Ambulance",
        "Road_Roller",
        "Shovel_Loader",
        "Light_DT",
        "Pickup",
        "Mixer_Truck",
        "Heavy_DT",
        "Medium_TT",
        "Light_PV",
        "Heavy_FT",
        "Forklift",
        "Heavy_ST",
        "Small_Bus",
        "Medium_Bus",
        "Large_Bus",
    ]

    TYPES = [
        "Great_Wall_Voleex_C50",
        "Hongqi_h5",
        "Hongqi_CA7180A3E",
        "Chang'an_CS75_Plus",
        "Chevrolet_Blazer_1998",
        "Changfeng_Cheetah_CFA6473C",
        "Jeep_Patriot",
        "Mitsubishi_Outlander_2003",
        "Lincoln_MKC",
        "Hawtai_EV160B",
        "Chery_qq3",
        "Buick_Excelle_GT",
        "Chery_Arrizo 5",
        "Lveco_Proud_2009",
        "JINBEI_SY5033XJH",
        "Changlin_8228-5",
        "SDLG_ZL40F",
        "Foton_BJ1045V9JB5-54",
        "FAW_Jiabao_T51",
        "WAW_Aochi_1800",
        "Huanghai_N1",
        "Great_Wall_poer",
        "CNHTC_HOWO",
        "Dongfeng_Tianjin_DFH2200B",
        "WAW_Aochi_Hongrui",
        "Dongfeng_Duolika",
        "JAC_Junling",
        "FAW_J6P",
        "SHACMAN_DeLong_M3000",
        "Hyundai_HLF25_II",
        "Dongfeng_Tianjin_KR230",
        "SHACMAN_DeLong_X3000",
        "Wuling_Rongguang_V",
        "Buick_GL8",
        "Chang'an_Starlight_4500",
        "Dongfeng_Forthing_Lingzhi",
        "Yangzi_YZK6590XCA",
        "Dongfeng_EQ6608LTV",
        "MAXUS_V80",
        "Yutong_ZK6120HY1",
    ]

    def __init__(
        self,
        root_dir: str,
        split: Union[Literal["train", "test", "all"], str],
        benchmark: Optional[str] = None,
        download: bool = False,
        class_level: Literal["type", "class", "category"] = "type",
        get_annotations: bool = False,
        transform: Optional[Callable] = None,
    ):
        super().__init__()
        self.root_dir = pathlib.Path(root_dir)
        self.split = split
        self.class_level = class_level
        self.download = download
        self.get_annotations = get_annotations
        self.transform = transform

        if benchmark is None:
            # if no benchmark is given, default behavior should be to use the entire dataset (SOC_40)
            logging.info(
                "No benchmark was specified. SOC_40 (full dataset) will be used."
            )
            benchmark = "SOC_40classes"

        self.benchmark = benchmark
        # validate class_level and benchmark before touching the filesystem
        self._verify_inputs()

        if self.benchmark == "SOC_40":
            # allow use of the name given to the benchmark in the paper instead of the file
            # name actually used in their repository
            # (more consistent with the rest of their file naming)
            self.benchmark = "SOC_40classes"

        self.benchmark_path = self.root_dir / self.benchmark

        if not self.benchmark_path.exists():
            if not download:
                raise RuntimeError(
                    f"{self.benchmark} benchmark not found. You can use download=True to download it"
                )
            else:
                self._download_dataset()

        # gather samples (paths without extension, one per .mat/.xml pair)
        self.datafiles = gather_ATRNetSTAR_datafiles(
            rootdir=self.benchmark_path, split=self.split
        )
        logging.debug(f"Found {len(self.datafiles)} samples.")

    def _verify_inputs(self) -> None:
        """
        Verify inputs are valid.

        Raises:
            ValueError: if `class_level` or `benchmark` is not one of the allowed values.
        """
        if self.class_level not in ["type", "class", "category"]:
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of 'type', 'class' or 'category'."
            )

        if self.benchmark not in self._ALLOWED_BENCHMARKS:
            benchmarks_with_quotes = [f"'{b}'" for b in self._ALLOWED_BENCHMARKS]
            raise ValueError(
                f"Unknown benchmark. Should be one of {', '.join(benchmarks_with_quotes)} or None"
            )

    def _download_dataset(self) -> None:
        """
        Downloads the specified benchmark.
        Will be placed in a directory named like the benchmark, in root_dir.

        Imports are local so the Hugging Face / 7z machinery is only required
        when a download is actually requested.
        """
        from .download import check_7z, download_benchmark

        # the archives are 7z files: fail early if no extractor is available
        check_7z()
        download_benchmark(
            benchmark=self.benchmark,
            root_dir=self.root_dir,
            hf_repo_id=self.HF_REPO_ID,
            hf_benchmark_path=self.HF_BENCHMARKS_DIR_PATH,
        )

    @property
    def classes(self) -> list[str]:
        """
        Get the names of all classes at class_level.
        """
        if self.class_level == "category":
            return self.CATEGORIES

        elif self.class_level == "class":
            return self.CLASSES

        elif self.class_level == "type":
            return self.TYPES

        else:
            # defensive: _verify_inputs() already rejects bad values in __init__,
            # but class_level could have been mutated afterwards
            raise ValueError(
                f"Unexpected class_level value. Got {self.class_level} instead of type, class or category."
            )

    def __len__(self) -> int:
        """Number of samples in the selected benchmark/split."""
        return len(self.datafiles)

    def __getitem__(self, index: int) -> tuple:
        """
        Returns the sample at the given index. Applies the transform
        if provided. If `get_annotations` is True, also return the entire annotation dict.
        The return type depends on the transform applied and whether or not get_annotations is True

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the classes list
            (Optional) annotation : the annotation dict of the sample. Only if `get_annotations` is True.

        """
        sample_name = self.datafiles[index]
        sample = ATRNetSTARSample(sample_name=sample_name, transform=self.transform)
        # the XML annotation stores the label under annotation["object"][class_level]
        class_idx = self.classes.index(sample.annotation["object"][self.class_level])

        if self.get_annotations:
            return (
                sample.data,
                class_idx,
                sample.annotation,
            )

        return sample.data, class_idx