Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/sample/dataset.py: 0%

62 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24import requests 

25import pathlib 

26import logging 

27from typing import Tuple 

28 

29# External imports 

30import torch 

31from torch.utils.data import Dataset 

32import scipy.io 

33import tqdm 

34import numpy as np 

35 

36# Local imports 

37from .filelist import filelist 

38 

# Base URL of the raw ``.mat`` files of the public SAMPLE repository.
# A per-file URL is built as ``{SAMPLE_base_link}{mode}/{cl}/{filename}`` where
# ``mode`` is "real" or "synth" and ``cl`` is the class name.
SAMPLE_base_link = "https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public/raw/refs/heads/master/mat_files/"

40 

41 

class SAMPLE(Dataset):
    """
    The SAMPLE dataset is made partly from real data provided by MSTAR and partly from synthetic data.

    The dataset is public and will be downloaded if requested and missing on drive.

    It is made of 10 classes of military vehicles: 2s1, bmp2, btr70, m1, m2, m35, m548, m60, t72, zsu23

    Indexing the dataset returns a triplet ``(complex_img, class_idx, meta)``
    where ``meta`` is a dictionary of acquisition parameters read from the
    ``.mat`` file.

    Arguments:
        rootdir (str): Path to the root directory where the dataset is stored or will be downloaded
        transform (torchvision.transforms.Compose): A list of torchvision transforms to apply to the complex image
        download (bool): Whether to download the data if missing on disk

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import SAMPLE
            from torchvision.transforms import v2

            transform = v2.Compose(
                transforms=[v2.ToImage(), v2.Resize(128), v2.CenterCrop(128)]
            )
            dataset = SAMPLE(
                rootdir, transform=transform, download=True
            )
            X, y, meta = dataset[0]

        Displayed below are some examples drawn randomly from SAMPLE. To plot them, we extracted
        only the magnitude of the signals although the data are indeed complex valued.

        .. image:: ../assets/datasets/SAMPLE.png
           :alt: Samples from MSTAR
           :width: 60%

    """

    # Keys of the .mat files that are forwarded to the caller as metadata
    _META_KEYS = (
        "azimuth",
        "elevation",
        "bandwidth",
        "center_freq",
        "range_pixel_spacing",
        "range_resolution",
        "xrange_pixel_spacing",
        "xrange_resolution",
    )

    def __init__(self, rootdir: str, transform=None, download: bool = False):
        """
        Check that every file listed in ``filelist`` is on disk (downloading
        the missing ones if ``download`` is True) and index the ``.mat`` files
        found in each class directory.

        Raises:
            FileNotFoundError: if a listed file is missing and ``download`` is False
        """
        super().__init__()
        self.rootdir = pathlib.Path(rootdir)
        self.transform = transform

        self.class_names = list(filelist["real"].keys())

        # Number of files we expect, used as the total of the progress bar
        self.tot_num_samples = sum(
            len(filelist[mode][cl])
            for cl in self.class_names
            for mode in ("real", "synth")
        )

        self.data_files = {}
        self.num_data_files = {}
        self.tot_num_data_files = 0

        # The context manager guarantees the progress bar is closed, even if
        # a file is missing or a download fails
        with tqdm.tqdm(total=self.tot_num_samples) as pbar:
            for cl in self.class_names:
                # Real and synthetic files of a class share the same directory
                for mode in ("real", "synth"):
                    for filename in filelist[mode][cl]:
                        filepath = self.rootdir / cl / filename
                        if not filepath.exists():
                            if not download:
                                raise FileNotFoundError(f"{filepath} not found")
                            url = f"{SAMPLE_base_link}{mode}/{cl}/{filename}"
                            self.download_file(url, filepath)
                        pbar.update(1)

                # glob() yields files in a filesystem dependent order; sorting
                # makes the index -> sample mapping reproducible
                self.data_files[cl] = sorted((self.rootdir / cl).glob("*.mat"))
                self.num_data_files[cl] = len(self.data_files[cl])
                self.tot_num_data_files += self.num_data_files[cl]

    def download_file(self, url: str, filepath: pathlib.Path, timeout: float = 60.0):
        """
        Download a file from an URL and save it on disk

        Args:
            url (str): URL to download the file from
            filepath (pathlib.Path): Path to save the file
            timeout (float): Number of seconds to wait for the server before giving up

        Raises:
            requests.HTTPError: if the server answers with an error status
        """
        logging.debug("Downloading %s to %s", url, filepath)
        # Ensure the target directory exists
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Download the file; fail before touching the disk on an HTTP error,
        # otherwise the error page would be saved as a (corrupted) .mat file
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()

        # Save the file on disk
        with open(filepath, "wb") as fh:
            fh.write(response.content)

    def __len__(self) -> int:
        """
        Return the total number of samples in the dataset
        """
        return self.tot_num_data_files

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, dict]:
        """
        Return a sample from the dataset

        Args:
            index (int): Index of the sample; negative values count from the end

        Returns:
            A tuple ``(complex_img, class_idx, meta)``. ``complex_img`` is the
            ``complex_img`` entry of the file with a trailing singleton axis
            appended, processed by ``transform`` if one was given.
            ``class_idx`` is the position of the class in ``class_names`` and
            ``meta`` maps acquisition parameter names to the values loaded
            from the file.

        Raises:
            IndexError: if ``index`` is out of range
        """
        # Support negative indices the same way python sequences do
        if index < 0:
            index += self.tot_num_data_files
        if not 0 <= index < self.tot_num_data_files:
            raise IndexError(
                f"index out of range for a dataset of {self.tot_num_data_files} samples"
            )

        # We look for the class from which the sample will be taken.
        # After the loop, index is the position of the sample within that class
        for key in self.data_files:
            if index < self.num_data_files[key]:
                break
            index -= self.num_data_files[key]

        filename = self.data_files[key][index]
        logging.debug("Reading SAMPLE file : %s", filename)

        data = scipy.io.loadmat(filename)

        # Below are the keys available in the mat files
        # dict_keys(['__header__', '__version__', '__globals__', 'aligned', 'azimuth', 'bandwidth', 'center_freq', 'complex_img', 'complex_img_unshifted', 'elevation', 'explanation', 'range_pixel_spacing', 'range_resolution', 'source_mstar_file', 'target_name', 'taylor_weights', 'xrange_pixel_spacing', 'xrange_resolution']
        meta = {k: data[k] for k in self._META_KEYS}

        # Append a trailing axis to the 2D image
        complex_img = data["complex_img"][:, :, np.newaxis]

        class_idx = self.class_names.index(key)

        if self.transform is not None:
            complex_img = self.transform(complex_img)

        return complex_img, class_idx, meta
87 # We look into rootdir if the data are available 

88 self.tot_num_samples = 0 

89 for cl in self.class_names: 

90 self.tot_num_samples += len(filelist["real"][cl]) 

91 self.tot_num_samples += len(filelist["synth"][cl]) 

92 

93 self.data_files = {} 

94 self.num_data_files = {} 

95 self.tot_num_data_files = 0 

96 

97 pbar = tqdm.tqdm(total=self.tot_num_samples) 

98 for cl in self.class_names: 

99 for mode in ["real", "synth"]: 

100 for filename in filelist[mode][cl]: 

101 filepath = self.rootdir / cl / filename 

102 if not filepath.exists(): 

103 if download: 

104 url = f"{SAMPLE_base_link}{mode}/{cl}/{filename}" 

105 self.download_file(url, filepath) 

106 else: 

107 raise FileNotFoundError(f"{filepath} not found") 

108 pbar.update(1) 

109 self.data_files[cl] = list((self.rootdir / cl).glob("*.mat")) 

110 self.num_data_files[cl] = len(self.data_files[cl]) 

111 self.tot_num_data_files += self.num_data_files[cl] 

112 

113 def download_file(self, url: str, filepath: pathlib.Path): 

114 """ 

115 Download a file from an URL and save it on disk 

116 

117 Args: 

118 url (str): URL to download the file from 

119 filepath (pathlib.Path): Path to save the file 

120 """ 

121 logging.debug(f"Downloading {url} to {filepath}") 

122 # Ensure the target directory exists 

123 filepath.parent.mkdir(parents=True, exist_ok=True) 

124 

125 # Donwload and save the file on disk 

126 response = requests.get(url) 

127 with open(filepath, "wb") as fh: 

128 fh.write(response.content) 

129 

130 def __len__(self) -> int: 

131 """ 

132 Return the total number of samples in the dataset 

133 """ 

134 return self.tot_num_data_files 

135 

136 def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]: 

137 """ 

138 Return a sample from the dataset 

139 """ 

140 if index >= self.tot_num_data_files: 

141 raise IndexError 

142 

143 # We look for the class from which the sample will be taken 

144 for key in self.data_files.keys(): 

145 if index < self.num_data_files[key]: 

146 break 

147 index -= self.num_data_files[key] 

148 

149 filename = self.data_files[key][index] 

150 logging.debug(f"Reading SAMPLE file : {filename}") 

151 

152 data = scipy.io.loadmat(filename) 

153 

154 # Below are the keys available in the mat files 

155 # dict_keys(['__header__', '__version__', '__globals__', 'aligned', 'azimuth', 'bandwidth', 'center_freq', 'complex_img', 'complex_img_unshifted', 'elevation', 'explanation', 'range_pixel_spacing', 'range_resolution', 'source_mstar_file', 'target_name', 'taylor_weights', 'xrange_pixel_spacing', 'xrange_resolution'] 

156 

157 meta = { 

158 k: data[k] 

159 for k in [ 

160 "azimuth", 

161 "elevation", 

162 "bandwidth", 

163 "center_freq", 

164 "range_pixel_spacing", 

165 "range_resolution", 

166 "xrange_pixel_spacing", 

167 "xrange_resolution", 

168 ] 

169 } 

170 

171 complex_img = data["complex_img"][:, :, np.newaxis] 

172 

173 class_idx = self.class_names.index(key) 

174 

175 if self.transform is not None: 

176 complex_img = self.transform(complex_img) 

177 

178 return complex_img, class_idx, meta