Coverage for / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / datasets / s1slc.py: 0%
40 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-12 14:36 +0000
1# MIT License
3# Copyright (c) 2025 Quentin Gabot
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
24# Standard imports
25import os
27# External imports
28import numpy as np
29from torch.utils.data import Dataset
class S1SLC(Dataset):
    r"""
    The Polarimetric SAR dataset with the labels provided by
    https://ieee-dataport.org/open-access/s1slccvdl-complex-valued-annotated-single-look-complex-sentinel-1-sar-dataset-complex

    We expect the data to be already downloaded and available on your drive.

    Arguments:
        root: the top root dir where the data are expected. The data should be organized as follows: Sao Paulo/HH.npy, Sao Paulo/HV.npy, Sao Paulo/Labels.npy, Houston/HH.npy, Houston/HV.npy, Houston/Labels.npy, Chicago/HH.npy, Chicago/HV.npy, Chicago/Labels.npy
        transform : the transform applied to the cropped image
        lazy_loading : deprecated; passing True raises a DeprecationWarning.

    Note:
        An example usage :

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import S1SLC

            def transform(patches):
                # If you wish, you could filter out some polarizations
                # S1SLC provides the dual HH, HV polarizations
                patches = [np.abs(patchi) for _, patchi in patches.items()]
                return np.stack(patches)

            dataset = S1SLC(rootdir, transform=transform)
            X, y = dataset[0]
    """

    def __init__(self, root, transform=None, lazy_loading=False):
        # Lazy loading was removed; keep the parameter for backward
        # compatibility but reject any attempt to enable it.
        if lazy_loading is True:
            raise DeprecationWarning(
                "Lazy loading is no longer supported for S1SLC dataset."
            )

        self.transform = transform

        # Collect the per-city subfolders. Sorted so that the global
        # sample indexing is deterministic across runs and platforms
        # (os.listdir order is filesystem dependent).
        subfolders = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )

        self.data = {}
        self.classes = set()

        for subfolder in subfolders:
            # Paths to the per-city .npy files
            hh_path = os.path.join(subfolder, "HH.npy")
            hv_path = os.path.join(subfolder, "HV.npy")
            labels_path = os.path.join(subfolder, "Labels.npy")

            # Memory-map the (possibly large) complex patch stacks; the
            # patches are materialized lazily in __getitem__.
            hh = np.load(hh_path, mmap_mode="r")
            hv = np.load(hv_path, mmap_mode="r")

            # Load labels and convert to 0-indexed integers.
            # reshape(-1) (rather than squeeze) guarantees a 1-D array
            # even when a city contains a single patch, where squeeze()
            # would produce a 0-d array and break iteration/len().
            labels = np.load(labels_path, mmap_mode="r")
            labels = labels.astype(int).reshape(-1) - 1

            # Accumulate the set of classes seen across all cities
            self.classes.update(labels.tolist())

            self.data[subfolder] = {
                "hh": hh,
                "hv": hv,
                "labels": labels,
            }

        # Expose the classes as a sorted, deterministic list of ints
        self.classes = sorted(int(c) for c in self.classes)

    def __len__(self):
        """Total number of patches across all cities."""
        return sum(len(city_data["labels"]) for city_data in self.data.values())

    def find_city_and_local_idx(self, idx):
        """
        Given a global index, find the corresponding city and local index within that city's data.
        This is done in O(n_cities) time, which is acceptable since n_cities=3

        Raises:
            IndexError: if idx is outside [0, len(self)).
        """
        cumulative = 0
        for city, city_data in self.data.items():
            city_size = len(city_data["labels"])
            if idx < cumulative + city_size:
                local_idx = idx - cumulative
                return city, local_idx
            cumulative += city_size
        raise IndexError("Index out of range")

    def __getitem__(self, idx):
        """
        Return the (image, label) pair at global index idx.

        The image stacks the HH and HV polarizations along a new leading
        axis, i.e. shape (2, H, W); the transform (if any) is applied to it.
        """
        city, local_idx = self.find_city_and_local_idx(idx)

        image = np.stack(
            [
                self.data[city]["hh"][local_idx],
                self.data[city]["hv"][local_idx],
            ]
        )
        label = self.data[city]["labels"][local_idx]

        if self.transform:
            image = self.transform(image)

        return image, label