Source code for torchcvnn.datasets.bretigny.dataset

# MIT License

# Copyright (c) 2024 Chengfang Ren, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
import pathlib
from typing import Tuple, Any

# External imports
from torch.utils.data import Dataset
import numpy as np


class Bretigny(Dataset):
    r"""
    Bretigny Dataset

    Arguments:
        root: the root directory containing the npz files for Bretigny
        fold: train (70%), valid (15%), or test (15%)
        transform: the transform applied to the cropped image
        balanced: whether or not to use balanced labels
        patch_size: the dimensions of the patches to consider (rows, cols)
        patch_stride: the shift between two consecutive patches, default: patch_size

    Note:
        An example usage:

        .. code-block:: python

            import torchcvnn
            from torchcvnn.datasets import Bretigny

            dataset = Bretigny(
                rootdir, fold="train", patch_size=(128, 128), transform=lambda x: np.abs(x)
            )
            X, y = dataset[0]

    Displayed below are the train, valid and test parts with the labels overlaid

    .. image:: ../assets/datasets/bretigny_train.png
       :alt: Train fold
       :width: 60%

    .. image:: ../assets/datasets/bretigny_valid.png
       :alt: Valid fold
       :width: 15%

    .. image:: ../assets/datasets/bretigny_test.png
       :alt: Test fold
       :width: 15%

    """

    """ Class names """
    classes = ["0 - Unlabeled", "1 - Forest", "2 - Track", "3 - Urban", "4 - Fields"]

    def __init__(
        self,
        root: str,
        fold: str,
        transform=None,
        balanced: bool = False,
        patch_size: tuple = (128, 128),
        patch_stride: tuple = None,
    ):
        self.root = pathlib.Path(root)
        self.transform = transform
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        if patch_stride is None:
            self.patch_stride = patch_size

        # Preload the data
        sar_filename = self.root / "bretigny_seg.npz"
        if not sar_filename.exists():
            raise RuntimeError(f"Cannot find the file {sar_filename}")
        sar_data = np.load(sar_filename)
        self.HH, self.HV, self.VV = sar_data["HH"], sar_data["HV"], sar_data["VV"]

        if balanced:
            label_filename = self.root / "bretigny_seg_4ROI_balanced.npz"
        else:
            label_filename = self.root / "bretigny_seg_4ROI.npz"
        if not label_filename.exists():
            raise RuntimeError(f"Cannot find the label file {label_filename}")
        self.labels = np.load(label_filename)["arr_0"]

        if fold not in ["train", "valid", "test"]:
            raise ValueError(
                f"Unrecognized fold {fold}. Should be either train, valid or test"
            )

        # Crop the data with respect to the fold
        if fold == "train":
            col_start = 0
            col_end = int(0.70 * self.HH.shape[1])
        elif fold == "valid":
            col_start = int(0.70 * self.HH.shape[1]) + 1
            col_end = int(0.85 * self.HH.shape[1])
        else:
            col_start = int(0.85 * self.HH.shape[1]) + 1
            col_end = self.HH.shape[1]

        self.HH = self.HH[:, col_start:col_end]
        self.HV = self.HV[:, col_start:col_end]
        self.VV = self.VV[:, col_start:col_end]
        self.labels = self.labels[:, col_start:col_end]

        # Precompute the dimension of the grid of patches
        nrows = self.HH.shape[0]
        ncols = self.HH.shape[1]

        nrows_patch, ncols_patch = self.patch_size
        row_stride, col_stride = self.patch_stride

        self.nsamples_per_rows = (nrows - nrows_patch) // row_stride + 1
        self.nsamples_per_cols = (ncols - ncols_patch) // col_stride + 1

    def __len__(self) -> int:
        """
        Returns the total number of patches in the whole image.

        Returns:
            the total number of patches in the dataset
        """
        return self.nsamples_per_rows * self.nsamples_per_cols

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """
        Returns the patch at the given index.

        Arguments:
            idx (int): Index

        Returns:
            tuple: (patch, labels) where patch contains the 3 complex valued
            polarizations HH, HV, VV and labels contains the aligned semantic labels
        """
        row_stride, col_stride = self.patch_stride
        start_row = (idx // self.nsamples_per_cols) * row_stride
        start_col = (idx % self.nsamples_per_cols) * col_stride
        num_rows, num_cols = self.patch_size

        patches = [
            patch[
                start_row : (start_row + num_rows), start_col : (start_col + num_cols)
            ]
            for patch in [self.HH, self.HV, self.VV]
        ]
        patches = np.stack(patches)
        if self.transform is not None:
            patches = self.transform(patches)

        labels = self.labels[
            start_row : (start_row + num_rows), start_col : (start_col + num_cols)
        ]

        return patches, labels
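

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal example, assuming the required npz files (bretigny_seg.npz and
# bretigny_seg_4ROI.npz) live under "/path/to/bretigny"; the path, batch size
# and transform below are illustrative choices, not values from the library.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Bretigny(
        root="/path/to/bretigny",  # assumed location of the npz files
        fold="train",
        patch_size=(128, 128),
        transform=lambda x: np.abs(x),  # keep only the modulus of HH, HV, VV
    )

    # Each sample is a (3, 128, 128) patch stacked over the three
    # polarizations, paired with its aligned (128, 128) label map.
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    patches, labels = next(iter(loader))
    print(patches.shape, labels.shape)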