Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/bretigny/dataset.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Chengfang Ren, Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24import pathlib 

25from typing import Tuple, Any 

26 

27# External imports 

28from torch.utils.data import Dataset 

29import numpy as np 

30 

31 

class Bretigny(Dataset):
    r"""
    Bretigny Dataset

    The scene is split column-wise into three contiguous, non-overlapping
    folds: the first 70% of the columns for ``train``, the next 15% for
    ``valid`` and the remaining columns for ``test``. The selected fold is then
    tiled into patches of size ``patch_size`` taken every ``patch_stride``
    pixels.

    Arguments:
        root: the root directory containing the npz files for Bretigny
        fold: train (70%), valid (15%), or test (15%)
        transform : the transform applied to the cropped image
        balanced: whether or not to use balanced labels
        patch_size: the dimensions of the patches to consider (rows, cols)
        patch_stride: the shift between two consecutive patches, default:patch_size

    Raises:
        ValueError: if ``fold`` is not one of train, valid or test
        RuntimeError: if the SAR or label npz file cannot be found in ``root``

    Note:
        An example usage :

        .. code-block:: python

            import numpy as np
            from torchcvnn.datasets import Bretigny

            dataset = Bretigny(
                rootdir, fold="train", patch_size=(128, 128), transform=lambda x: np.abs(x)
            )
            X, y = dataset[0]

        Displayed below are the train, valid and test parts with the labels overlayed

        .. image:: ../assets/datasets/bretigny_train.png
           :alt: Train fold
           :width: 60%
        .. image:: ../assets/datasets/bretigny_valid.png
           :alt: Valid fold
           :width: 15%
        .. image:: ../assets/datasets/bretigny_test.png
           :alt: Test fold
           :width: 15%

    """

    # Class names. The numeric prefix of each name presumably matches the
    # integer value used in the label maps -- TODO confirm against the data.
    classes = ["0 - Unlabeled", "1 - Forest", "2 - Track", "3 - Urban", "4 - Fields"]

    def __init__(
        self,
        root: str,
        fold: str,
        transform=None,
        balanced: bool = False,
        patch_size: tuple = (128, 128),
        patch_stride: tuple = None,
    ):
        # Validate the cheap argument before reading anything from disk
        if fold not in ("train", "valid", "test"):
            raise ValueError(
                f"Unrecognized fold {fold}. Should be either train, valid or test"
            )

        self.root = pathlib.Path(root)
        self.transform = transform

        self.patch_size = patch_size
        # Non-overlapping patches by default
        self.patch_stride = patch_size if patch_stride is None else patch_stride

        # Preload the data. The ``with`` blocks close the npz archives; indexing
        # an archive by key has already read that array fully into memory.
        sar_filename = self.root / "bretigny_seg.npz"
        if not sar_filename.exists():
            raise RuntimeError(f"Cannot find the file {sar_filename}")
        with np.load(sar_filename) as sar_data:
            self.HH, self.HV, self.VV = sar_data["HH"], sar_data["HV"], sar_data["VV"]

        if balanced:
            label_filename = self.root / "bretigny_seg_4ROI_balanced.npz"
        else:
            label_filename = self.root / "bretigny_seg_4ROI.npz"
        if not label_filename.exists():
            raise RuntimeError(f"Cannot find the label file {label_filename}")
        with np.load(label_filename) as label_data:
            self.labels = label_data["arr_0"]

        # Crop the data with respect to the fold
        col_start, col_end = self._fold_columns(fold, self.HH.shape[1])
        self.HH = self.HH[:, col_start:col_end]
        self.HV = self.HV[:, col_start:col_end]
        self.VV = self.VV[:, col_start:col_end]
        self.labels = self.labels[:, col_start:col_end]

        # Precompute the dimension of the grid of patches
        nrows = self.HH.shape[0]
        ncols = self.HH.shape[1]

        nrows_patch, ncols_patch = self.patch_size
        row_stride, col_stride = self.patch_stride

        # A patch larger than the cropped image would give a negative count
        # (and two negatives would multiply into a bogus positive length):
        # clamp to zero patches instead.
        self.nsamples_per_rows = max(0, (nrows - nrows_patch) // row_stride + 1)
        self.nsamples_per_cols = max(0, (ncols - ncols_patch) // col_stride + 1)

    @staticmethod
    def _fold_columns(fold: str, ncols: int) -> Tuple[int, int]:
        """
        Returns the half-open column range ``[start, end)`` of a fold.

        The end of one fold is the start of the next one, so the three folds
        tile the ``ncols`` columns without gaps or overlaps.

        Arguments:
            fold: one of train, valid or test
            ncols: the number of columns of the full image

        Returns:
            (start, end) column bounds, usable directly in a slice
        """
        train_end = int(0.70 * ncols)
        valid_end = int(0.85 * ncols)
        bounds = {
            "train": (0, train_end),
            "valid": (train_end, valid_end),
            "test": (valid_end, ncols),
        }
        return bounds[fold]

    def __len__(self) -> int:
        """
        Returns the total number of patches in the selected fold.

        Returns:
            the total number of patches in the dataset
        """
        return self.nsamples_per_rows * self.nsamples_per_cols

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """
        Returns the indexed patch.

        Arguments:
            idx (int): Index of the patch. Patches are enumerated in row-major
                order over the grid of patches; negative indices count from
                the end.

        Returns:
            tuple: (patch, labels) where patch contains the 3 complex valued polarization HH, HV, VV and labels contains the aligned semantic labels

        Raises:
            IndexError: if ``idx`` is out of range. This also lets plain Python
                iteration over the dataset stop at the last patch.
        """
        nsamples = len(self)
        position = idx + nsamples if idx < 0 else idx
        if not 0 <= position < nsamples:
            raise IndexError(
                f"Patch index {idx} out of range for a dataset of {nsamples} patches"
            )

        row_stride, col_stride = self.patch_stride
        num_rows, num_cols = self.patch_size
        start_row = (position // self.nsamples_per_cols) * row_stride
        start_col = (position % self.nsamples_per_cols) * col_stride
        window = (
            slice(start_row, start_row + num_rows),
            slice(start_col, start_col + num_cols),
        )

        patches = np.stack([channel[window] for channel in (self.HH, self.HV, self.VV)])
        if self.transform is not None:
            patches = self.transform(patches)

        labels = self.labels[window]

        return patches, labels