Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/sample/dataset.py: 0%
62 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24import requests
25import pathlib
26import logging
27from typing import Tuple
29# External imports
30import torch
31from torch.utils.data import Dataset
32import scipy.io
33import tqdm
34import numpy as np
36# Local imports
37from .filelist import filelist
# Root of the ``mat_files`` folder of the public SAMPLE repository on GitHub.
# Individual files are fetched from ``<root><real|synth>/<class>/<filename>``,
# hence the mandatory trailing slash.
SAMPLE_base_link = (
    "https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public/"
    "raw/refs/heads/master/mat_files/"
)
class SAMPLE(Dataset):
    """
    The SAMPLE dataset is made partly from real data provided by MSTAR and partly from synthetic data.

    The dataset is public and will be downloaded if requested and missing on drive.

    It is made of 10 classes of military vehicles: 2s1, bmp2, btr70, m1, m2, m35, m548, m60, t72, zsu23

    Arguments:
        rootdir (str): Path to the root directory where the dataset is stored or will be downloaded
        transform (torchvision.transforms.Compose): A list of torchvision transforms to apply to the complex image
        download (bool): Whether to download the data if missing on disk

    Each sample is returned as a ``(complex_img, class_idx, meta)`` triplet:

    - ``complex_img``: the ``complex_img`` array stored in the ``.mat`` file with an
      extra trailing channel axis (``H x W x 1``), after ``transform`` has been applied,
    - ``class_idx``: the index of the sample's class in ``class_names``,
    - ``meta``: a dictionary with the acquisition metadata read from the ``.mat`` file
      (azimuth, elevation, bandwidth, center_freq, range/xrange pixel spacing and resolution).

    Samples are ordered class by class (following ``class_names``) and, within a
    class, by file path.

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchvision.transforms import v2
            from torchcvnn.datasets import SAMPLE

            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = SAMPLE(
                rootdir, transform=transform, download=True
            )
            X, y, meta = dataset[0]

    Displayed below are some examples drawn randomly from SAMPLE. To plot them, we extracted
    only the magnitude of the signals although the data are indeed complex valued.

    .. image:: ../assets/datasets/SAMPLE.png
       :alt: Samples from SAMPLE
       :width: 60%

    """

    def __init__(self, rootdir: str, transform=None, download: bool = False):
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        self.class_names = list(filelist["real"].keys())

        # Number of files expected from the reference file list (real + synthetic)
        self.tot_num_samples = 0
        for cl in self.class_names:
            self.tot_num_samples += len(filelist["real"][cl])
            self.tot_num_samples += len(filelist["synth"][cl])

        self.data_files = {}
        self.num_data_files = {}
        self.tot_num_data_files = 0

        # Check that every expected file is on disk, downloading the missing ones
        # if allowed. The context manager closes the progress bar even when a
        # missing file or a failed download interrupts the scan.
        with tqdm.tqdm(total=self.tot_num_samples) as pbar:
            for cl in self.class_names:
                for mode in ["real", "synth"]:
                    for filename in filelist[mode][cl]:
                        # Real and synthetic files of a class are stored in the
                        # same local directory
                        filepath = self.rootdir / cl / filename
                        if not filepath.exists():
                            if download:
                                url = f"{SAMPLE_base_link}{mode}/{cl}/{filename}"
                                self.download_file(url, filepath)
                            else:
                                raise FileNotFoundError(f"{filepath} not found")
                        pbar.update(1)
                # glob() order depends on the filesystem; sorting makes the
                # index -> file mapping reproducible across machines
                self.data_files[cl] = sorted((self.rootdir / cl).glob("*.mat"))
                self.num_data_files[cl] = len(self.data_files[cl])
                self.tot_num_data_files += self.num_data_files[cl]

    def download_file(self, url: str, filepath: pathlib.Path, timeout: float = 60.0):
        """
        Download a file from an URL and save it on disk

        The content is first written to a ``<name>.part`` file next to the target
        and only renamed to ``filepath`` once completely written. This way an
        interrupted or failed download never leaves a truncated file behind that
        would later be mistaken for a valid one.

        Args:
            url (str): URL to download the file from
            filepath (pathlib.Path): Path to save the file
            timeout (float): Timeout, in seconds, given to ``requests.get``

        Raises:
            requests.HTTPError: if the server answers with an HTTP error status
        """
        logging.debug("Downloading %s to %s", url, filepath)
        # Ensure the target directory exists
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Download the file; without the status check an error page (e.g. 404)
        # would silently be saved as a .mat file
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        tmp_path = filepath.with_name(filepath.name + ".part")
        with open(tmp_path, "wb") as fh:
            fh.write(response.content)
        tmp_path.replace(filepath)

    def __len__(self) -> int:
        """
        Return the total number of samples in the dataset
        """
        return self.tot_num_data_files

    def __getitem__(self, index: int) -> Tuple[object, int, dict]:
        """
        Return a sample from the dataset

        Negative indices count from the end of the dataset, as for Python sequences.

        Args:
            index (int): Index of the sample, in ``[-len(self), len(self))``

        Returns:
            tuple: ``(complex_img, class_idx, meta)`` where ``complex_img`` is the
            (transformed) ``H x W x 1`` complex image, ``class_idx`` the index of
            the class in ``class_names`` and ``meta`` a dict of metadata arrays
            as returned by ``scipy.io.loadmat``.

        Raises:
            IndexError: if ``index`` is out of range
        """
        requested = index
        if index < 0:
            index += self.tot_num_data_files
        if not 0 <= index < self.tot_num_data_files:
            raise IndexError(
                f"Index {requested} out of range for a dataset of "
                f"{self.tot_num_data_files} samples"
            )

        # We look for the class from which the sample will be taken; samples are
        # laid out class after class, in the insertion order of data_files
        for key in self.data_files.keys():
            if index < self.num_data_files[key]:
                break
            index -= self.num_data_files[key]

        filename = self.data_files[key][index]
        logging.debug("Reading SAMPLE file : %s", filename)

        data = scipy.io.loadmat(filename)

        # Below are the keys available in the mat files
        # dict_keys(['__header__', '__version__', '__globals__', 'aligned', 'azimuth', 'bandwidth', 'center_freq', 'complex_img', 'complex_img_unshifted', 'elevation', 'explanation', 'range_pixel_spacing', 'range_resolution', 'source_mstar_file', 'target_name', 'taylor_weights', 'xrange_pixel_spacing', 'xrange_resolution']

        meta = {
            k: data[k]
            for k in [
                "azimuth",
                "elevation",
                "bandwidth",
                "center_freq",
                "range_pixel_spacing",
                "range_resolution",
                "xrange_pixel_spacing",
                "xrange_resolution",
            ]
        }

        # Add a trailing channel axis: (H, W) -> (H, W, 1)
        complex_img = data["complex_img"][:, :, np.newaxis]

        class_idx = self.class_names.index(key)

        if self.transform is not None:
            complex_img = self.transform(complex_img)

        return complex_img, class_idx, meta