Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/mstar/dataset.py: 0%

109 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Tuple 

25import pathlib 

26import logging 

27import struct 

28 

29# External imports 

30import torch 

31from torch.utils.data import Dataset 

32import numpy as np 

33 

34 

def parse_header(fh) -> dict:
    """
    Parse the PhoenixHeader from the provided text file handle.

    The first line of the file is skipped (one character precedes the
    header), the next line must be one of the accepted header version tags,
    and every following line up to ``[EndofPhoenixHeader]`` is read as a
    ``key= value`` pair.

    Arguments:
        fh : a file handle opened in text mode, positioned at the beginning
             of the file

    Returns:
        A dictionary mapping every field name of the PhoenixHeader to its
        value. Both keys and values are stripped strings.

    Raises:
        ValueError : if the header does not start with
                     [PhoenixHeaderVer01.04] or [PhoenixHeaderVer01.05], if
                     the end of the file is reached before
                     [EndofPhoenixHeader], or if a header line does not
                     contain a ``=`` separator.
    """
    accepted_headers = ["[PhoenixHeaderVer01.04]", "[PhoenixHeaderVer01.05]"]
    end_marker = "[EndofPhoenixHeader]"

    # There is one character at the very beginning of the file, before the header
    _ = fh.readline()
    start_line = fh.readline().strip()
    if start_line not in accepted_headers:
        raise ValueError(
            f"Invalid header : {start_line}, expected one of {accepted_headers}"
        )

    parsed_fields = {}
    while True:
        raw_line = fh.readline()
        # readline() returns an empty string only at the end of the file; a
        # blank line still carries its newline character.
        if not raw_line:
            raise ValueError(
                f"Invalid header : reached the end of the file before {end_marker}"
            )

        line = raw_line.strip()
        if line == end_marker:
            return parsed_fields

        # Split on the first "=" only, so that a value which itself contains
        # "=" is kept whole.
        key, separator, value = line.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid header line : {line!r}, expected 'key= value'"
            )
        parsed_fields[key.strip()] = value.strip()

61 

62 

class MSTARSample:
    """
    This class implements a sample from the MSTAR dataset.

    The extracted complex image is stored in the data attribute.
    The data is a complex valued numpy array of shape (num_rows, num_cols, 1),
    built as ``magnitude * exp(1j * phase)``.

    The header is stored in the header attribute. It contains all the fields
    of the PhoenixHeader.

    Only files storing their values as big-endian 4 bytes floats are
    supported.

    Arguments:
        filename : the name of the file to load

    Raises:
        NotImplementedError : if the file stores its values on 2 bytes
        ValueError : if the header is invalid, if the number of bytes per
                     value is neither 2 nor 4, or if the file is too short
                     to hold the magnitude and phase planes
    """

    def __init__(self, filename: str):
        self.filename = pathlib.Path(filename)

        # The header is textual; undecodable bytes of the binary payload are
        # replaced instead of raising a decoding error.
        with open(self.filename, "r", errors="replace") as fh:
            self.header = parse_header(fh)

        num_rows = int(self.header["NumberOfRows"])
        num_cols = int(self.header["NumberOfColumns"])
        phoenix_header_length = int(self.header["PhoenixHeaderLength"])
        native_header_length = int(self.header["native_header_length"])
        sig_size = int(self.header["PhoenixSigSize"])

        # The payload starts right after the two headers
        data_offset = phoenix_header_length + native_header_length
        num_values = num_rows * num_cols

        # The payload holds two planes (magnitudes then phases) of
        # num_rows * num_cols values each.
        bytes_per_values = (sig_size - data_offset) // (2 * num_values)

        if bytes_per_values == 4:
            # Big-endian float32 values
            value_dtype = np.dtype(">f4")
        elif bytes_per_values == 2:
            # Fail early and explicitly instead of attempting to decode
            # 2 bytes values as float32.
            raise NotImplementedError(
                f"Reading 2 bytes values is not supported : {self.filename}"
            )
        else:
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
            )

        # Read the data from the file
        with open(self.filename, "rb") as fh:
            fh.seek(data_offset)
            magnitudes = self._read_plane(fh, num_rows, num_cols, value_dtype)
            phases = self._read_plane(fh, num_rows, num_cols, value_dtype)

        self.data = magnitudes * np.exp(1j * phases)
        self.data = self.data[:, :, np.newaxis]

    @staticmethod
    def _read_plane(fh, num_rows: int, num_cols: int, value_dtype) -> np.ndarray:
        """
        Read one (num_rows, num_cols) plane of values from the binary file
        handle and return it as a float64 array.
        """
        num_bytes = num_rows * num_cols * value_dtype.itemsize
        data_bytes = fh.read(num_bytes)
        if len(data_bytes) != num_bytes:
            raise ValueError(
                f"Truncated data : expected {num_bytes} bytes, got {len(data_bytes)}"
            )
        # Cast to float64 so that the complex image is double precision
        values = np.frombuffer(data_bytes, dtype=value_dtype).astype(np.float64)
        return values.reshape(num_rows, num_cols)

121 

122 

def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) -> dict:
    """
    This function gathers all the MSTAR datafiles from the root directory.
    It looks recursively for files named HBxxxx that are data files
    (containing a valid PhoenixHeader).

    The assigned target name is the path component located
    target_name_depth levels from the end of the file path : 1 is the name
    of the file itself, 2 the name of its directory, 3 the name of the parent
    directory, and so on.

    The files are visited in sorted order so that the order of the classes
    and of the files within each class does not depend on the order in which
    the filesystem lists the directory entries.

    Arguments:
        rootdir : the directory to explore
        target_name_depth : the depth, counted from the end of the path, of
                            the path component used as the target name

    Returns:
        A dictionary mapping every target name to the list of the paths of
        its data files.
    """

    data_files = {}
    for filename in sorted(rootdir.glob("**/HB*")):
        if not filename.is_file():
            continue

        # Any failure to open the file or to parse its header means this is
        # not a MSTAR data file : it is deliberately skipped, not fatal.
        try:
            with open(filename, "r", errors="replace") as fh:
                _ = parse_header(fh)
        except Exception as e:
            logging.debug(
                f"The file {filename} failed to be loaded as a MSTAR sample: {e}"
            )
            continue

        target_name = filename.parts[-target_name_depth]
        logging.debug(f"Successfully parsed {filename} as a {target_name} sample.")
        data_files.setdefault(target_name, []).append(filename)
    return data_files

153 

154 

class MSTARTargets(Dataset):
    """
    This class implements a PyTorch Dataset for the MSTAR dataset.

    The MSTAR dataset is composed of several sub-datasets. The datasets must
    be downloaded manually because they require authentication.

    To download these datasets, you must register at the following address: https://www.sdms.afrl.af.mil/index.php?collection=mstar

    This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following :

    - MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY :
      https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=targets

    A sub-dataset whose directory is missing is skipped with a warning.

    Arguments:
        rootdir : str, the directory in which the sub-datasets are unpacked
        transform : the transform applied on the input complex valued array

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import MSTARTargets
            from torchvision.transforms import v2

            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = MSTARTargets(
                rootdir, transform=transform
            )
            X, y = dataset[0]

        Displayed below are some examples for every class in the dataset. To plot them, we extracted
        only the magnitude of the signals although the data are indeed complex valued.

        .. image:: ../assets/datasets/mstar.png
           :alt: Samples from MSTAR
           :width: 60%


    """

    def __init__(self, rootdir: str, transform=None):
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        # The MSTAR dataset is composed of several sub-datasets
        # Each sub-dataset has a different layout
        # The dictionary below maps the directory name of the sub-dataset
        # to the depth at which the target name is located in the directory structure
        # with respect to a datafile
        sub_datasets = {
            "MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2,
            "MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2,
            "MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3,
        }

        # We collect all the samples from all the sub-datasets, merging the
        # files of a class that appears in several sub-datasets
        self.data_files = {}
        for sub_dataset, target_name_depth in sub_datasets.items():
            sub_dir = self.rootdir / sub_dataset
            if not sub_dir.exists():
                logging.warning(f"Directory {sub_dir} does not exist.")
                continue
            gathered = gather_mstar_datafiles(sub_dir, target_name_depth)
            for key, value in gathered.items():
                self.data_files.setdefault(key, []).extend(value)

        # The index of a class is its position in this list
        self.class_names = list(self.data_files.keys())

        # We then count how many samples have been loaded for all the classes
        self.num_data_files = {
            key: len(self.data_files[key]) for key in self.class_names
        }
        self.tot_num_data_files = sum(self.num_data_files.values())

        logging.debug(
            f"Loaded {self.tot_num_data_files} MSTAR samples from the following classes : {self.class_names}."
        )
        # List the number of samples per class
        for key in self.class_names:
            logging.debug(f"Class {key} : {self.num_data_files[key]} samples.")

    def __len__(self) -> int:
        """
        Returns the total number of samples
        """
        return self.tot_num_data_files

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Returns the sample at the given index. Applies the transform
        if provided. The type of the first component of the tuple
        depends on the provided transform. If None is provided, the
        sample is a complex valued numpy array.

        The samples are indexed class after class, following the order
        of the class_names list. Negative indices are counted from the
        end of the dataset.

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the class_names list

        Raises:
            IndexError : if the index is out of range
        """

        num_samples = self.tot_num_data_files
        requested_index = index
        # Negative indices address the dataset from its end, as for any
        # python sequence
        if index < 0:
            index += num_samples
        if index < 0 or index >= num_samples:
            raise IndexError(
                f"Index {requested_index} is out of range for a dataset of {num_samples} samples"
            )

        # We look for the class from which the sample will be taken ; on
        # exit, index is the position of the sample within that class
        for class_idx, key in enumerate(self.class_names):
            num_in_class = self.num_data_files[key]
            if index < num_in_class:
                break
            index -= num_in_class
        else:
            # Only reachable if the per-class counts no longer add up to
            # the total number of samples
            raise IndexError(
                f"Index {requested_index} is out of range for a dataset of {num_samples} samples"
            )

        filename = self.data_files[key][index]
        logging.debug(f"Loading the MSTAR file {filename}")

        sample = MSTARSample(filename)

        data = sample.data
        if self.transform is not None:
            data = self.transform(data)

        return data, class_idx