Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/mstar/dataset.py: 0%
109 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Tuple
25import pathlib
26import logging
27import struct
29# External imports
30import torch
31from torch.utils.data import Dataset
32import numpy as np
def parse_header(fh) -> dict:
    """
    Parse the PhoenixHeader from the provided text-mode file handle.

    The first line of the file is skipped, the second one must be one of the
    accepted header markers, and every following line, up to the
    ``[EndofPhoenixHeader]`` marker, is read as a ``key= value`` pair.

    Arguments:
        fh : a file handle opened in text mode and positioned at the
            beginning of the file

    Returns:
        A dictionary containing all the fields of the PhoenixHeader. Keys and
        values are stripped strings; no type conversion is performed.

    Raises:
        ValueError : if the header does not start with
            [PhoenixHeaderVer01.04] or [PhoenixHeaderVer01.05], if a header
            line has no ``=`` separator, or if the end of the file is reached
            before [EndofPhoenixHeader].
    """
    accepted_headers = ["[PhoenixHeaderVer01.04]", "[PhoenixHeaderVer01.05]"]
    end_marker = "[EndofPhoenixHeader]"

    # There is one character at the very beginning of the file, before the header
    _ = fh.readline()
    start_line = fh.readline().strip()
    if start_line not in accepted_headers:
        raise ValueError(
            f"Invalid header : {start_line}, expected one of {accepted_headers}"
        )

    parsed_fields = {}
    while True:
        raw_line = fh.readline()
        # readline() returns an empty string only at the end of the file
        if raw_line == "":
            raise ValueError(
                f"Invalid header : end of file reached before {end_marker}"
            )

        line = raw_line.strip()
        if line == end_marker:
            break

        # Split on the first "=" only, so that a value which itself contains
        # an "=" is kept in full.
        key, separator, value = line.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid header line : {line!r}, expected 'key= value'"
            )
        parsed_fields[key.strip()] = value.strip()

    return parsed_fields
class MSTARSample:
    """
    This class implements a sample from the MSTAR dataset.

    The extracted complex image is stored in the data attribute.
    The data is a complex valued numpy array of shape (num_rows, num_cols, 1),
    built as ``magnitudes * exp(1j * phases)``.

    The header is stored in the header attribute. It contains all the fields
    of the PhoenixHeader, as strings.

    Arguments:
        filename : the name of the file to load

    Raises:
        KeyError : if a required field is missing from the header
        ValueError : if the header is invalid, if the image dimensions are not
            positive, if the sample width is not supported or if the file is
            too short to contain the magnitude and phase blocks
    """

    def __init__(self, filename: str):
        self.filename = pathlib.Path(filename)

        # The header is read in text mode; undecodable bytes are replaced
        # rather than raising, since the same file also holds the binary data.
        with open(self.filename, "r", errors="replace") as fh:
            self.header = parse_header(fh)

        num_rows = int(self.header["NumberOfRows"])
        num_cols = int(self.header["NumberOfColumns"])
        phoenix_header_length = int(self.header["PhoenixHeaderLength"])
        native_header_length = int(self.header["native_header_length"])
        sig_size = int(self.header["PhoenixSigSize"])

        num_values = num_rows * num_cols
        if num_rows <= 0 or num_cols <= 0:
            raise ValueError(
                f"Invalid image dimensions : {num_rows} rows x {num_cols} columns"
            )

        # The payload starts right after the two headers and holds two blocks
        # of num_values values each: the magnitudes, then the phases.
        data_offset = phoenix_header_length + native_header_length
        bytes_per_values = (sig_size - data_offset) // (2 * num_values)

        if bytes_per_values == 2:
            # 16 bits samples are recognised, but the decoding below only
            # handles 32 bits floats. Fail with an explicit message rather
            # than with a buffer size error while unpacking.
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
                " (uint16 samples are not decoded yet)"
            )
        if bytes_per_values != 4:
            raise ValueError(
                f"Unsupported number of bytes per value : {bytes_per_values}"
            )

        # Read the data from the file
        block_size = num_values * bytes_per_values
        with open(self.filename, "rb") as fh:
            fh.seek(data_offset)
            magnitude_bytes = fh.read(block_size)
            phase_bytes = fh.read(block_size)

        if len(magnitude_bytes) != block_size or len(phase_bytes) != block_size:
            raise ValueError(
                f"Truncated data in {self.filename} : expected two blocks of"
                f" {block_size} bytes after offset {data_offset}"
            )

        # Values are stored as big-endian float32. They are promoted to
        # float64 so that the resulting image is complex128.
        magnitudes = (
            np.frombuffer(magnitude_bytes, dtype=">f4")
            .astype(np.float64)
            .reshape(num_rows, num_cols)
        )
        phases = (
            np.frombuffer(phase_bytes, dtype=">f4")
            .astype(np.float64)
            .reshape(num_rows, num_cols)
        )

        # Add a trailing axis of size 1: (num_rows, num_cols, 1)
        self.data = (magnitudes * np.exp(1j * phases))[:, :, np.newaxis]
def gather_mstar_datafiles(rootdir: pathlib.Path, target_name_depth: int = 1) -> dict:
    """
    This function gathers all the MSTAR datafiles from the root directory.
    It looks recursively for files named HBxxxx that are data files, i.e.
    whose PhoenixHeader can be parsed. The other files are skipped.

    The assigned target name is the component of the file path located at the
    target_name_depth level, counted from the end of the path: 1 is the name
    of the file itself, 2 the name of its directory, 3 the name of the parent
    directory, and so on.

    Arguments:
        rootdir : the directory in which to look for the data files
        target_name_depth : the depth, counted from the end of the path of a
            data file, of the path component used as target name

    Returns:
        A dictionary mapping every target name to the list of the paths of
        its data files. The paths are visited in sorted order.

    Raises:
        ValueError : if target_name_depth is lower than 1
        IndexError : if target_name_depth exceeds the number of components
            of the path of a data file
    """
    if target_name_depth < 1:
        raise ValueError(
            f"target_name_depth must be at least 1, got {target_name_depth}"
        )

    data_files = {}
    # The order in which glob yields the paths is not specified. Sorting makes
    # the order of the targets and of their files reproducible.
    for filename in sorted(rootdir.glob("**/HB*")):
        if not filename.is_file():
            continue

        # Best effort: whatever the reason for which the header cannot be
        # parsed, the file is not a data file and is skipped.
        try:
            with open(filename, "r", errors="replace") as fh:
                parse_header(fh)
        except Exception as e:
            logging.debug(
                "The file %s failed to be loaded as a MSTAR sample: %s", filename, e
            )
            continue

        target_name = filename.parts[-target_name_depth]
        logging.debug("Successfully parsed %s as a %s sample.", filename, target_name)
        data_files.setdefault(target_name, []).append(filename)
    return data_files
class MSTARTargets(Dataset):
    """
    This class implements a PyTorch Dataset for the MSTAR dataset.

    The MSTAR dataset is composed of several sub-datasets. The datasets must
    be downloaded manually because they require authentication.

    To download these datasets, you must register at the following address: https://www.sdms.afrl.af.mil/index.php?collection=mstar

    This dataset object expects all the datasets to be unpacked in the same directory. We can parse the following :

    - MSTAR_PUBLIC_T_72_VARIANTS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_T_72_VARIANTS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=variants
    - MSTAR_PUBLIC_MIXED_TARGETS_CD1 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_MIXED_TARGETS_CD2 : https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=mixed
    - MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY :
      https://www.sdms.afrl.af.mil/index.php?collection=mstar&page=targets

    The sub-datasets that are not found in the root directory are skipped
    with a warning. The classes are indexed in the order in which they are
    first encountered while scanning the sub-datasets.

    Arguments:
        rootdir : str
        transform : the transform applied on the input complex valued array

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import MSTARTargets
            from torchvision.transforms import v2

            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = MSTARTargets(
                rootdir, transform=transform
            )
            X, y = dataset[0]

        Displayed below are some examples for every class in the dataset. To plot them, we extracted
        only the magnitude of the signals although the data are indeed complex valued.

        .. image:: ../assets/datasets/mstar.png
           :alt: Samples from MSTAR
           :width: 60%


    """

    def __init__(self, rootdir: str, transform=None):
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        # The MSTAR dataset is composed of several sub-datasets
        # Each sub-dataset has a different layout
        # The dictionary below maps the directory name of the sub-dataset
        # to the depth at which the target name is located in the directory structure
        # with respect to a datafile
        sub_datasets = {
            "MSTAR_PUBLIC_T_72_VARIANTS_CD1": 2,
            "MSTAR_PUBLIC_T_72_VARIANTS_CD2": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD1": 2,
            "MSTAR_PUBLIC_MIXED_TARGETS_CD2": 2,
            "MSTAR_PUBLIC_TARGETS_CHIPS_T72_BMP2_BTR70_SLICY": 3,
        }

        # We collect all the samples from all the sub-datasets
        self.data_files = {}
        for sub_dataset, target_name_depth in sub_datasets.items():
            sub_dir = self.rootdir / sub_dataset
            if not sub_dir.exists():
                logging.warning(f"Directory {sub_dir} does not exist.")
                continue
            # Append the data files from the sub-dataset. A class may be
            # spread over several sub-datasets.
            gathered = gather_mstar_datafiles(sub_dir, target_name_depth)
            for key, value in gathered.items():
                self.data_files.setdefault(key, []).extend(value)
        self.class_names = list(self.data_files.keys())

        # We then count how many samples have been loaded for all the classes
        self.num_data_files = {
            key: len(files) for key, files in self.data_files.items()
        }
        self.tot_num_data_files = sum(self.num_data_files.values())

        logging.debug(
            "Loaded %s MSTAR samples from the following classes : %s.",
            self.tot_num_data_files,
            self.class_names,
        )
        # List the number of samples per class
        for key in self.class_names:
            logging.debug("Class %s : %s samples.", key, self.num_data_files[key])

    def __len__(self) -> int:
        """
        Returns the total number of samples
        """
        return self.tot_num_data_files

    def _locate_sample(self, index: int) -> Tuple[str, int]:
        """
        Maps a dataset-wide index to a class name and to the index of the
        sample within the files of that class.

        The samples are ordered class after class, following the order of the
        classes in data_files. Negative indices are counted from the end of
        the dataset.

        Raises:
            IndexError : if the index is out of range
        """
        total = self.tot_num_data_files
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(
                f"Index {index} is out of range for a dataset of {total} samples"
            )

        # We look for the class from which the sample will be taken
        for key in self.data_files:
            num_files = self.num_data_files[key]
            if position < num_files:
                return key, position
            position -= num_files

        # Only reachable if the counters do not match data_files
        raise IndexError(
            f"Index {index} is out of range for a dataset of {total} samples"
        )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Returns the sample at the given index. Applies the transform
        if provided. The type of the first component of the tuple
        depends on the provided transform. If None is provided, the
        sample is a complex valued numpy array.

        Arguments:
            index : index of the sample to return. A negative index is
                counted from the end of the dataset.

        Returns:
            data : the sample
            class_idx : the index of the class in the class_names list

        Raises:
            IndexError : if the index is out of range
        """
        key, local_index = self._locate_sample(index)

        filename = self.data_files[key][local_index]
        logging.debug("Loading the MSTAR file %s", filename)

        sample = MSTARSample(filename)
        class_idx = self.class_names.index(key)

        data = sample.data
        if self.transform is not None:
            data = self.transform(data)

        return data, class_idx