Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/polsf.py: 0%
27 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
1# MIT License
3# Copyright (c) 2024 Quentin Gabot
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24import pathlib
25from typing import Tuple, Any
27# External imports
28import numpy as np
29from torch.utils.data import Dataset
30from PIL import Image
32# Local imports
33from .alos2 import ALOSDataset
class PolSFDataset(Dataset):
    r"""
    The Polarimetric SAR dataset with the labels provided by
    https://ietr-lab.univ-rennes1.fr/polsarpro-bio/san-francisco/

    We expect the data to be already downloaded and available on your drive.
    Both the ALOS-2 volume directory ``VOL-ALOS2044980750-150324-HBQR1.1__A``
    and the label map ``SF-ALOS2-label2d.png`` are expected directly under
    ``root``. The data can be obtained from:

    - https://ietr-lab.univ-rennes1.fr/polsarpro-bio/san-francisco/dataset/SAN_FRANCISCO_ALOS2.zip
    - https://raw.githubusercontent.com/liuxuvip/PolSF/master/SF-ALOS2/SF-ALOS2-label2d.png

    Arguments:
        root: the top root dir where the data are expected
        transform: the transform applied to each cropped image patch (it is
            forwarded to the underlying ALOS-2 dataset; the labels are
            returned untransformed)
        patch_size: the dimensions of the patches to consider (rows, cols)
        patch_stride: the shift between two consecutive patches,
            default: patch_size

    Note:
        An example usage :

        .. code-block:: python

            import numpy as np
            import torchcvnn
            from torchcvnn.datasets import PolSFDataset

            def transform_patches(patches):
                # We keep all the patches and get the spectrum
                # from it
                # If you wish, you could filter out some polarizations
                # PolSF provides the four HH, HV, VH, VV
                patches = [np.abs(patchi) for _, patchi in patches.items()]
                return np.stack(patches)

            dataset = PolSFDataset(rootdir, patch_size=(512, 512), transform=transform_patches)
            X, y = dataset[0]

        Displayed below are example patches with patch sizes :math:`512 \times 512`
        with the labels overlaid

        .. figure:: ../assets/datasets/polsf.png
            :alt: Patches from the PolSF dataset
            :width: 100%
            :align: center

    """

    # Class names; presumably the entry at position k names the label
    # value k found in the label map (TODO confirm against the PNG palette).
    classes = [
        "0 - unlabel",
        "1 - Montain",
        "2 - Water",
        "3 - Vegetation",
        "4 - High-Density Urban",
        "5 - Low-Density Urban",
        "6 - Developd",
    ]

    def __init__(
        self,
        root: str,
        transform=None,
        patch_size: tuple = (128, 128),
        patch_stride: tuple = None,
    ):
        """
        Build the ALOS-2 patch dataset and load the aligned label map.

        Arguments:
            root: directory containing the ALOS-2 volume directory and the
                ``SF-ALOS2-label2d.png`` label map
            transform: transform forwarded to the ALOS-2 dataset
            patch_size: patch dimensions (rows, cols)
            patch_stride: shift between consecutive patches (rows, cols);
                ``None`` is forwarded as-is to the ALOS-2 dataset
        """
        self.root = root
        root_path = pathlib.Path(root)

        # Crop applied by ALOSDataset to the full ALOS-2 scene; patches are
        # indexed within this crop, and the label map is indexed with the
        # same offsets, so it is presumably aligned with this region.
        crop_coordinates = ((2832, 736), (7888, 3520))
        alos_root = root_path / "VOL-ALOS2044980750-150324-HBQR1.1__A"
        self.alos_dataset = ALOSDataset(
            alos_root, transform, crop_coordinates, patch_size, patch_stride
        )

        # Use a context manager so the PNG file handle is released
        # deterministically once the pixels have been copied into numpy.
        with Image.open(root_path / "SF-ALOS2-label2d.png") as label_image:
            label_map = np.array(label_image)
        # The label map is flipped along its first (row) axis, presumably to
        # match the orientation of the cropped ALOS-2 image (TODO confirm).
        # The copy is necessary as otherwise torch.from_numpy does not
        # support the resulting negative stride.
        self.labels = label_map[::-1, :].copy()

    def __len__(self) -> int:
        """
        Returns the total number of patches in the whole image.

        Returns:
            the total number of patches in the dataset
        """
        return len(self.alos_dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Returns the indexed patch together with its aligned labels.

        Negative indices count from the end, as for Python sequences.

        Arguments:
            idx (int): Index

        Returns:
            tuple: (patch, labels) where patch contains the 4 complex valued
            polarization HH, HV, VH, VV (after ``transform``, if any) and labels
            contains the aligned semantic labels

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``
        """
        num_patches = len(self)
        requested_idx = idx
        # Normalise negative indices so the image patch and the label window
        # are computed from the same non-negative position.
        if idx < 0:
            idx += num_patches
        # Raising IndexError also lets the sequence iteration protocol
        # (``for x in dataset``) terminate after the last patch.
        if not 0 <= idx < num_patches:
            raise IndexError(
                f"Index {requested_idx} out of range for a dataset of "
                f"{num_patches} patches"
            )

        alos_patch = self.alos_dataset[idx]

        # Patches are laid out row-major: nsamples_per_cols patches per row.
        row_stride, col_stride = self.alos_dataset.patch_stride
        start_row = (idx // self.alos_dataset.nsamples_per_cols) * row_stride
        start_col = (idx % self.alos_dataset.nsamples_per_cols) * col_stride
        num_rows, num_cols = self.alos_dataset.patch_size
        labels = self.labels[
            start_row : (start_row + num_rows), start_col : (start_col + num_cols)
        ]

        return alos_patch, labels