Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/mstar/dataset.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-12 14:36 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Tuple 

25import pathlib 

26import logging 

27import struct 

28 

29# External imports 

30import torch 

31from torch.utils.data import Dataset 

32import numpy as np 

33 

34 

def parse_header(fh) -> dict:
    """
    Parse the PhoenixHeader of an MSTAR file from an open text file handle.

    The very first line of the file is skipped. The next line must be one of
    the accepted header start markers. Every following "key = value" line is
    then collected until the "[EndofPhoenixHeader]" marker is met. Blank lines
    inside the header are ignored. Only the first "=" separates key from
    value, so values may themselves contain "=".

    Arguments:
        fh : a file handle opened in text mode, positioned at the start of the file

    Returns:
        a dictionary mapping every PhoenixHeader field name to its (stripped,
        string) value

    Raises:
        ValueError : if the header does not start with [PhoenixHeaderVer01.04]
            or [PhoenixHeaderVer01.05], if a header line is not of the form
            "key = value", or if the end of the file is reached before
            [EndofPhoenixHeader].
    """
    parsed_fields = {}
    # There is one character at the very beginning of the file, before the header
    _ = fh.readline()
    start_line = fh.readline().strip()
    accepted_headers = ["[PhoenixHeaderVer01.04]", "[PhoenixHeaderVer01.05]"]
    if start_line not in accepted_headers:
        raise ValueError(
            f"Invalid header : {start_line}, expected one of {accepted_headers}"
        )

    while True:
        raw_line = fh.readline()
        # readline() returns "" only at end of file: without this check a
        # truncated header would surface as an obscure error.
        if not raw_line:
            raise ValueError("Unexpected end of file before [EndofPhoenixHeader]")
        line = raw_line.strip()
        if line == "[EndofPhoenixHeader]":
            break
        if not line:
            continue
        # partition keeps everything after the first "=" as the value
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"Invalid header line : {line}")
        parsed_fields[key.strip()] = value.strip()

    return parsed_fields

61 

62 

class MSTARSample:
    """
    This class implements a sample from the MSTAR dataset.

    The extracted complex image is stored in the data attribute.
    The data is a complex128 numpy array of shape (num_rows, num_cols, 1),
    computed as magnitude * exp(1j * phase).

    The header is stored in the header attribute. It contains all the fields
    of the PhoenixHeader.

    Arguments:
        filename : the name of the file to load

    Raises:
        ValueError : if the header is invalid, if the image size is not
            positive, if the values are not stored on 4 bytes (big-endian
            float32), or if the file is too short to contain the magnitude and
            phase blocks announced by the header.
    """

    def __init__(self, filename: str):
        self.filename = pathlib.Path(filename)

        # The header is text; decoding errors are replaced so that the binary
        # payload following it cannot break the text reader.
        with open(self.filename, "r", errors="replace") as fh:
            self.header = parse_header(fh)

        num_rows = int(self.header["NumberOfRows"])
        num_cols = int(self.header["NumberOfColumns"])
        phoenix_header_length = int(self.header["PhoenixHeaderLength"])
        native_header_length = int(self.header["native_header_length"])
        data_offset = phoenix_header_length + native_header_length

        num_values = num_rows * num_cols
        if num_values <= 0:
            raise ValueError(
                f"Invalid image size in {self.filename} : {num_rows} x {num_cols}"
            )

        # The payload holds two blocks (magnitudes, then phases) of
        # num_rows * num_cols values each; the width of one value is deduced
        # from the total signal size announced in the header.
        sig_size = int(self.header["PhoenixSigSize"])
        bytes_per_values = (sig_size - data_offset) // (2 * num_values)

        if bytes_per_values != 4:
            # Only big-endian float32 payloads are decoded. 2-byte (uint16)
            # payloads are rejected explicitly because their conversion to
            # magnitude/phase is not implemented.
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
            )

        # Read the magnitude and phase blocks in one go
        expected_size = 2 * num_values * bytes_per_values
        with open(self.filename, "rb") as fh:
            fh.seek(data_offset)
            payload = fh.read(expected_size)
        if len(payload) != expected_size:
            raise ValueError(
                f"File {self.filename} is truncated : expected {expected_size} "
                f"bytes of data, got {len(payload)}"
            )

        # ">f4" is big-endian float32; widen to float64 so the resulting
        # complex image is complex128.
        values = np.frombuffer(payload, dtype=">f4").astype(np.float64)
        magnitudes = values[:num_values].reshape(num_rows, num_cols)
        phases = values[num_values:].reshape(num_rows, num_cols)

        self.data = (magnitudes * np.exp(1j * phases))[:, :, np.newaxis]

121 

122 

def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) -> dict:
    """
    Gather all the MSTAR datafiles below a root directory, grouped by target.

    It looks recursively for files named HB* and keeps those whose content
    starts with a valid PhoenixHeader. Files that cannot be parsed are skipped
    and logged at debug level.

    The assigned target name is the path component at position
    -target_name_depth of the datafile path. 1 is the file name itself,
    2 its parent directory, 3 the grand-parent directory, and so on.

    Files are visited in sorted path order. The returned lists are therefore
    reproducible and do not depend on the filesystem's listing order.

    Arguments:
        rootdir : the directory to scan (a str is also accepted)
        target_name_depth : depth of the path component used as target name

    Returns:
        a dictionary mapping each target name to the list of its datafile paths

    Raises:
        ValueError : if target_name_depth is smaller than 1
    """
    if target_name_depth < 1:
        raise ValueError(
            f"target_name_depth must be a positive integer, got {target_name_depth}"
        )
    rootdir = pathlib.Path(rootdir)

    data_files = {}
    for filename in sorted(rootdir.glob("**/HB*")):
        if not filename.is_file():
            continue

        try:
            with open(filename, "r", errors="replace") as fh:
                parse_header(fh)
        except Exception as e:
            # Best effort scan: anything that is not a readable MSTAR file is skipped
            logging.debug(
                f"The file {filename} failed to be loaded as a MSTAR sample: {e}"
            )
            continue

        target_name = filename.parts[-target_name_depth]
        logging.debug(f"Successfully parsed {filename} as a {target_name} sample.")
        data_files.setdefault(target_name, []).append(filename)
    return data_files

153 

154 

class MSTARTargets(Dataset):
    """
    This class implements a PyTorch Dataset for the MSTAR dataset.

    The MSTAR dataset is composed of several sub-datasets. The datasets must
    be downloaded manually because they require authentication.

    To download these datasets, you must register at the following address: https://www.sdms.afrl.af.mil/index.php?collection=mstar

    This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following :

    - MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY :
      https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=targets

    Sub-datasets missing from rootdir are skipped with a warning.

    Arguments:
        rootdir : str (or pathlib.Path), the directory containing the unpacked sub-datasets
        transform : the transform applied on the input complex valued array

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchvision.transforms import v2
            from torchcvnn.datasets import MSTARTargets

            rootdir = "/path/to/mstar"
            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = MSTARTargets(
                rootdir, transform=transform
            )
            X, y = dataset[0]

        Displayed below are some examples for every class in the dataset. To plot them, we extracted
        only the magnitude of the signals although the data are indeed complex valued.

        .. image:: ../assets/datasets/mstar.png
           :alt: Samples from MSTAR
           :width: 60%

    """

    def __init__(self, rootdir: str, transform=None):
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        # The MSTAR dataset is composed of several sub-datasets, each with its
        # own directory layout. The dictionary below maps the directory name
        # of a sub-dataset to the depth at which the target name is located
        # in the directory structure with respect to a datafile.
        sub_datasets = {
            "MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2,
            "MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2,
            "MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3,
        }

        # Collect the data files of every available sub-dataset, grouped by class
        self.data_files = {}
        for sub_dataset, target_name_depth in sub_datasets.items():
            sub_dir = self.rootdir / sub_dataset
            if not sub_dir.exists():
                logging.warning(f"Directory {sub_dir} does not exist.")
                continue
            gathered = gather_mstar_datafiles(sub_dir, target_name_depth)
            for key, value in gathered.items():
                self.data_files.setdefault(key, []).extend(value)
        # Class indices follow the insertion order of self.data_files
        self.class_names = list(self.data_files.keys())

        # Count how many samples have been loaded per class and in total
        self.num_data_files = {
            key: len(self.data_files[key]) for key in self.class_names
        }
        self.tot_num_data_files = sum(self.num_data_files.values())

        logging.debug(
            f"Loaded {self.tot_num_data_files} MSTAR samples from the following classes : {self.class_names}."
        )
        # List the number of samples per class
        for key in self.class_names:
            logging.debug(f"Class {key} : {self.num_data_files[key]} samples.")

    def __len__(self) -> int:
        """
        Returns the total number of samples
        """
        return self.tot_num_data_files

    def _locate(self, index: int) -> Tuple[pathlib.Path, int]:
        """
        Map a global sample index to its (datafile path, class index) pair.

        Samples are numbered class after class, in the order of class_names.
        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError : if index is outside [-len(self), len(self))
        """
        num_samples = self.tot_num_data_files
        if index < 0:
            index += num_samples
        if not 0 <= index < num_samples:
            raise IndexError(
                f"Index out of range for a dataset of {num_samples} samples"
            )

        for class_idx, key in enumerate(self.class_names):
            class_size = self.num_data_files[key]
            if index < class_size:
                return self.data_files[key][index], class_idx
            index -= class_size

        # Only reachable if the per-class counts disagree with the total
        raise IndexError(
            f"Index out of range for a dataset of {num_samples} samples"
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Returns the sample at the given index. Applies the transform
        if provided. The type of the first component of the tuple
        depends on the provided transform. If None is provided, the
        sample is a complex valued numpy array.

        Arguments:
            index : index of the sample to return (negative indices count
                from the end)

        Returns:
            data : the sample
            class_idx : the index of the class in the class_names list

        Raises:
            IndexError : if index is outside [-len(self), len(self))
        """
        filename, class_idx = self._locate(index)
        logging.debug(f"Loading the MSTAR file {filename}")

        sample = MSTARSample(filename)

        data = sample.data
        if self.transform is not None:
            data = self.transform(data)

        return data, class_idx