Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/mstar/dataset.py: 0%
108 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 14:36 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 14:36 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Tuple
25import pathlib
26import logging
27import struct
29# External imports
30import torch
31from torch.utils.data import Dataset
32import numpy as np
def parse_header(fh) -> dict:
    """
    Parse the PhoenixHeader read from the provided text file handle.

    The first line of the file is skipped, the second one must be one of the
    accepted header version tags, and every following line is expected to be
    of the form ``key = value`` until the ``[EndofPhoenixHeader]`` marker is
    met. Only the first ``=`` of a line separates the key from the value, so
    values are allowed to contain ``=`` themselves.

    Arguments:
        fh : a file-like object opened in text mode (only ``readline`` is used)

    Returns:
        A dictionary mapping every field name of the PhoenixHeader to its
        value. Keys and values are stripped strings.

    Raises:
        ValueError : if the header does not start with
            [PhoenixHeaderVer01.04] or [PhoenixHeaderVer01.05], if a header
            line is not of the form ``key = value``, or if the end of the
            file is reached before the [EndofPhoenixHeader] marker.
    """
    accepted_headers = ["[PhoenixHeaderVer01.04]", "[PhoenixHeaderVer01.05]"]
    end_marker = "[EndofPhoenixHeader]"

    # There is one character at the very beginning of the file, before the header
    _ = fh.readline()
    start_line = fh.readline().strip()
    if start_line not in accepted_headers:
        raise ValueError(
            f"Invalid header : {start_line}, expected one of {accepted_headers}"
        )

    parsed_fields = {}
    while True:
        raw_line = fh.readline()
        # readline() returns an empty string only at the end of the file; an
        # empty line inside the file still carries its newline character.
        if not raw_line:
            raise ValueError(
                f"Invalid header : end of file reached before {end_marker}"
            )

        line = raw_line.strip()
        if line == end_marker:
            return parsed_fields

        # partition() splits on the first "=" only, which keeps values that
        # contain "=" intact.
        key, separator, value = line.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid header line : {line!r}, expected 'key = value'"
            )
        parsed_fields[key.strip()] = value.strip()
class MSTARSample:
    """
    This class implements a sample from the MSTAR dataset.

    The extracted complex image is stored in the data attribute.
    The data is a complex valued numpy array of shape (num_rows, num_cols, 1),
    built as ``magnitude * exp(1j * phase)``.

    The header is stored in the header attribute. It contains all the fields
    of the PhoenixHeader.

    Only files storing their values as big-endian 32 bits floats are decoded.

    Arguments:
        filename : the name of the file to load

    Raises:
        ValueError : if the header is invalid, if the number of bytes per
            value deduced from the header is not 4, or if the file is too
            short to contain the magnitude and phase blocks.
    """

    def __init__(self, filename: str):
        self.filename = pathlib.Path(filename)

        # The header is plain text. The file also holds binary data, hence
        # errors="replace" so that undecodable bytes never raise while reading.
        with open(self.filename, "r", errors="replace") as fh:
            self.header = parse_header(fh)

        num_rows = int(self.header["NumberOfRows"])
        num_cols = int(self.header["NumberOfColumns"])
        phoenix_header_length = int(self.header["PhoenixHeaderLength"])
        native_header_length = int(self.header["native_header_length"])
        sig_size = int(self.header["PhoenixSigSize"])

        # The payload starts right after the two headers and is made of two
        # consecutive blocks (magnitudes, then phases) of num_rows * num_cols
        # values each.
        data_offset = phoenix_header_length + native_header_length
        num_values = num_rows * num_cols
        bytes_per_values = (sig_size - data_offset) // (2 * num_values)

        if bytes_per_values == 2:
            # 16 bits samples were previously accepted here and then failed
            # while being unpacked as 32 bits floats. Fail early and clearly.
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
                " (decoding of 16 bits samples is not implemented)"
            )
        if bytes_per_values != 4:
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
            )

        # Read the magnitude and phase blocks in a single call
        expected_num_bytes = 2 * num_values * bytes_per_values
        with open(self.filename, "rb") as fh:
            fh.seek(data_offset)
            data_bytes = fh.read(expected_num_bytes)

        if len(data_bytes) != expected_num_bytes:
            raise ValueError(
                f"The file {self.filename} is truncated : expected "
                f"{expected_num_bytes} bytes of data, got {len(data_bytes)}"
            )

        # Big-endian float32 values, promoted to float64 so that the resulting
        # array is complex128.
        values = np.frombuffer(data_bytes, dtype=">f4").astype(np.float64)
        magnitudes = values[:num_values].reshape(num_rows, num_cols)
        phases = values[num_values:].reshape(num_rows, num_cols)

        self.data = magnitudes * np.exp(1j * phases)
        self.data = self.data[:, :, np.newaxis]
def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) -> dict:
    """
    Gather all the MSTAR datafiles found below the root directory.

    It looks recursively for files whose name starts with HB and keeps the
    ones that are data files, i.e. the ones from which a PhoenixHeader can be
    parsed. The other files are skipped and reported at the debug level.

    The assigned target name is the component of the file path located
    target_name_depth components from the end, i.e.
    ``filename.parts[-target_name_depth]``.

    The candidate files are visited in sorted order so that the order of the
    targets and of the files of every target does not depend on the
    filesystem.

    Arguments:
        rootdir : the directory to search recursively
        target_name_depth : which component of the file path, counted from
            the end, is used as the target name

    Returns:
        A dictionary mapping every target name to the list of the paths of
        its datafiles.
    """
    data_files = {}

    # glob() yields its results in an arbitrary order; sorting them keeps
    # the ordering of the returned dictionary reproducible.
    for filename in sorted(rootdir.glob("**/HB*")):
        if not filename.is_file():
            continue

        # Probing is best effort: anything that cannot be parsed as a
        # PhoenixHeader is not a datafile and is ignored.
        try:
            with open(filename, "r", errors="replace") as fh:
                parse_header(fh)
        except Exception as e:
            logging.debug(
                f"The file {filename} failed to be loaded as a MSTAR sample: {e}"
            )
            continue

        target_name = filename.parts[-target_name_depth]
        logging.debug(f"Successfully parsed {filename} as a {target_name} sample.")
        data_files.setdefault(target_name, []).append(filename)

    return data_files
class MSTARTargets(Dataset):
    """
    This class implements a PyTorch Dataset for the MSTAR dataset.

    The MSTAR dataset is composed of several sub-datasets. The datasets must
    be downloaded manually because they require authentication.

    To download these datasets, you must register at the following address: https://www.sdms.afrl.af.mil/index.php?collection=mstar

    This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following :

    - MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY :
      https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=targets

    Sub-datasets whose directory is missing are skipped with a warning.

    Arguments:
        rootdir : str
        transform : the transform applied on the input complex valued array

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import MSTARTargets
            from torchvision.transforms import v2

            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = MSTARTargets(
                rootdir, transform=transform
            )
            X, y = dataset[0]

        Displayed below are some examples for every class in the dataset. To plot them, we extracted
        only the magnitude of the signals although the data are indeed complex valued.

        .. image:: ../assets/datasets/mstar.png
           :alt: Samples from MSTAR
           :width: 60%

    """

    def __init__(self, rootdir: str, transform=None):
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        # The MSTAR dataset is composed of several sub-datasets
        # Each sub-dataset has a different layout
        # The dictionary below maps the directory name of the sub-dataset
        # to the depth at which the target name is located in the directory structure
        # with respect to a datafile
        sub_datasets = {
            "MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2,
            "MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2,
            "MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3,
        }

        # We collect all the samples from all the sub-datasets, merging the
        # files of the targets that appear in several sub-datasets
        self.data_files = {}
        for sub_dataset, target_name_depth in sub_datasets.items():
            sub_dir = self.rootdir / sub_dataset
            if not sub_dir.exists():
                logging.warning("Directory %s does not exist.", sub_dir)
                continue
            gathered = gather_mstar_datafiles(sub_dir, target_name_depth)
            for key, value in gathered.items():
                self.data_files.setdefault(key, []).extend(value)
        self.class_names = list(self.data_files.keys())

        # We then count how many samples have been loaded for all the classes
        self.num_data_files = {
            key: len(self.data_files[key]) for key in self.class_names
        }
        self.tot_num_data_files = sum(self.num_data_files.values())

        logging.debug(
            "Loaded %s MSTAR samples from the following classes : %s.",
            self.tot_num_data_files,
            self.class_names,
        )
        # List the number of samples per class
        for key in self.class_names:
            logging.debug("Class %s : %s samples.", key, self.num_data_files[key])

    def __len__(self) -> int:
        """
        Returns the total number of samples
        """
        return self.tot_num_data_files

    def _locate(self, index: int):
        """
        Maps a dataset wide index to the class and the file it refers to.

        The samples are ordered class after class, following the order of
        class_names. Negative indices count from the end of the dataset.

        Arguments:
            index : index of the sample in the whole dataset

        Returns:
            class_idx : the index of the class in the class_names list
            filename : the entry of data_files for that sample

        Raises:
            IndexError : if the index is out of range
        """
        total = self.tot_num_data_files
        if not -total <= index < total:
            raise IndexError(
                f"Index {index} is out of range for a dataset of {total} samples"
            )
        if index < 0:
            index += total

        # We look for the class from which the sample will be taken
        for class_idx, key in enumerate(self.class_names):
            num_in_class = self.num_data_files[key]
            if index < num_in_class:
                return class_idx, self.data_files[key][index]
            index -= num_in_class

        # Only reachable if the per-class counts disagree with the total
        raise IndexError(f"Index {index} could not be mapped to a sample")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Returns the sample at the given index. Applies the transform
        if provided. The type of the first component of the tuple
        depends on the provided transform. If None is provided, the
        sample is a complex valued numpy array.

        Negative indices are supported and count from the end of the dataset.

        Arguments:
            index : index of the sample to return

        Returns:
            data : the sample
            class_idx : the index of the class in the class_names list

        Raises:
            IndexError : if the index is out of range
        """
        class_idx, filename = self._locate(index)
        logging.debug("Loading the MSTAR file %s", filename)

        data = MSTARSample(filename).data
        if self.transform is not None:
            data = self.transform(data)

        return data, class_idx