Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/miccai2023.py: 0%

82 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# coding: utf-8 

2 

3# MIT License 

4 

5# Copyright (c) 2024 Jeremy Fix 

6 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13 

14# The above copyright notice and this permission notice shall be included in 

15# all copies or substantial portions of the Software. 

16 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25# Standard imports 

26from enum import Enum 

27import pathlib 

28import logging 

29from typing import Union 

30 

31# External imports 

32import torch 

33from torch.utils.data import Dataset 

34import h5py # Required because the data are matlab v7.3 files 

35import numpy as np 

36 

37 

class CINEView(Enum):
    """Cine MRI view to load from the MICCAI2023 dataset.

    ``SAX`` selects the short axis view (``cine_sax*.mat`` files) and ``LAX``
    the long axis view (``cine_lax*.mat`` files).
    """

    SAX = 1  # short axis view
    LAX = 2  # long axis view

41 

42 

class AccFactor(Enum):
    """Acceleration (k-space subsampling) factor of the undersampled data.

    The integer value is what the dataset uses to build the
    ``AccFactor{value:02d}`` directory name and the ``kspace_sub{value:02d}`` /
    ``mask{value:02d}`` keys of the MATLAB files.
    """

    ACC4 = 4  # 4x acceleration
    ACC8 = 8  # 8x acceleration
    ACC10 = 10  # 10x acceleration

47 

48 

def load_matlab_file(filename: Union[str, pathlib.Path], key: str) -> np.ndarray:
    """
    Load a single variable from a MATLAB v7.3 file (HDF5 format)

    Arguments:
        filename : str or pathlib.Path
            path to the .mat file
        key : str
            name of the variable (HDF5 dataset) to read

    Returns:
        np.ndarray
            the data stored under ``key``, fully read into memory

    Raises:
        KeyError
            if ``key`` is not present in the file
    """
    with h5py.File(filename, "r") as f:
        available_keys = list(f.keys())
        logging.debug(f"Got the keys {available_keys} from {filename}")
        if key not in f:
            # Report what the file actually contains instead of h5py's generic message
            raise KeyError(
                f"Key {key!r} not found in {filename}; available keys: {available_keys}"
            )
        # [()] reads the whole dataset before the file is closed
        return f[key][()]

57 

58 

def kspace_to_image(
    kspace: Union[torch.Tensor, np.ndarray],
) -> Union[torch.Tensor, np.ndarray]:
    """
    Convert centered k-space data to a centered image with a 2D inverse FFT.

    The transform (and the associated shifts) are applied on the last two
    axes only; any leading axes are treated as batch axes. The returned image
    is of the same type as the provided k-space (np.ndarray or torch.Tensor).

    Arguments:
        kspace : torch.Tensor or np.ndarray
            k-space data, the last two axes being the k-space axes

    Returns:
        torch.Tensor or np.ndarray
            image data, same shape as ``kspace``
    """
    if isinstance(kspace, torch.Tensor):
        dims = (-2, -1)
        # Shift only the transformed axes: shifting every axis would also roll
        # leading batch/coil axes that ifft2 does not transform.
        return torch.fft.fftshift(
            torch.fft.ifft2(torch.fft.ifftshift(kspace, dim=dims), dim=dims),
            dim=dims,
        )
    axes = (-2, -1)
    return np.fft.fftshift(
        np.fft.ifft2(np.fft.ifftshift(kspace, axes=axes), axes=axes),
        axes=axes,
    )

77 

78 

def image_to_kspace(
    img: Union[torch.Tensor, np.ndarray],
) -> Union[torch.Tensor, np.ndarray]:
    """
    Convert a centered image to centered k-space data with a 2D FFT.

    The transform (and the associated shifts) are applied on the last two
    axes only; any leading axes are treated as batch axes. The returned
    k-space is of the same type as the provided image (np.ndarray or
    torch.Tensor).

    Arguments:
        img : torch.Tensor or np.ndarray
            Image data, the last two axes being the spatial axes

    Returns:
        torch.Tensor or np.ndarray
            k-space data, same shape as ``img``
    """
    if isinstance(img, torch.Tensor):
        dims = (-2, -1)
        # Shift only the transformed axes so leading batch/coil axes are untouched
        return torch.fft.fftshift(
            torch.fft.fft2(torch.fft.ifftshift(img, dim=dims), dim=dims),
            dim=dims,
        )
    axes = (-2, -1)
    return np.fft.fftshift(
        np.fft.fft2(np.fft.ifftshift(img, axes=axes), axes=axes),
        axes=axes,
    )

100 

101 

def combine_coils_from_kspace(kspace: np.ndarray) -> np.ndarray:
    """
    Combine the coils of the k-space data using the root sum of squares

    Each coil k-space is converted to a centered image with a 2D inverse FFT
    over the last two axes, and the magnitudes are combined over the coil axis.

    Arguments:
        kspace : np.ndarray
            k-space data of shape (sc, ky, kx)

    Returns:
        np.ndarray
            Image data with coils combined, of shape (ky, kx), real valued, positive

    Raises:
        ValueError
            if ``kspace`` does not have exactly 3 dimensions
    """
    if kspace.ndim != 3:
        raise ValueError(
            f"kspace should have 3 dimensions, got {kspace.ndim}. Expected dimensions (sc, ky, kx)"
        )
    spatial_axes = (-2, -1)
    # Per-coil centered images; the shifts act on the spatial axes only and
    # never roll the coil axis.
    coil_images = np.fft.fftshift(
        np.fft.ifft2(np.fft.ifftshift(kspace, axes=spatial_axes), axes=spatial_axes),
        axes=spatial_axes,
    )
    # Root sum of squares over the coil axis
    return np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=0))

120 

121 

class MICCAI2023(Dataset):
    """
    Loads the MICCAI2023 challenge data for the reconstruction task Task 1

    The data are described on https://cmrxrecon.github.io/Task1-Cine-reconstruction.html

    You need to download the data beforehand in order to use this class.

    For loading the data, you may want to alternatively consider the fastmri library, see https://github.com/facebookresearch/fastMRI/

    The structure of the dataset is as follows, ``rootdir`` being the directory
    that contains ``MultiCoil`` (e.g. the ``ChallengeData`` directory)::

        rootdir/MultiCoil/cine/TrainingSet/P{id}/
            - cine_sax.mat
            - cine_lax.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor04/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor08/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor10/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat

    Only the patients for which the subsampled data, the mask and the fully
    sampled data are all present are kept. They are sorted by directory name,
    so that a given index always refers to the same patient.

    The cine_sax or cine_lax files are :math:`(k_x, k_y, s_c, s_z, t)` where :

    - :math:`k_x`: matrix size in x-axis (k-space)
    - :math:`k_y`: matrix size in y-axis (k-space)
    - :math:`s_c`: coil array number (compressed to 10)
    - :math:`s_x`: matrix size in x-axis (image)
    - :math:`s_y`: matrix size in y-axis (image) , used in single-coil data
    - :math:`s_z`: slice number for short axis view, or slice group for long axis (i.e., 3ch, 2ch and 4ch views)
    - :math:`t`: time frame.

    Note the k-space dimensions (in x/y axis) are not the same depending on the patient.

    This is a reconstruction dataset. The goal is to reconstruct the fully sampled k-space
    from the subsampled k-space. The acceleration factor specifies the subsampling rate.

    There are also the Single-Coil data which are not yet considered by this implementation.

    Arguments:
        rootdir: directory containing ``MultiCoil``
        view: the view to load, ``CINEView.SAX`` or ``CINEView.LAX``
        acc_factor: the acceleration factor of the subsampled data, either an
            ``AccFactor`` member or its integer value (4, 8 or 10)

    Raises:
        ValueError: if ``view`` or ``acc_factor`` is not a valid value
        FileNotFoundError: if the subsampled data directory does not exist

    Note:
        An example usage :

        .. code-block:: python

            import numpy as np
            from torchcvnn.datasets.miccai2023 import (
                MICCAI2023,
                CINEView,
                AccFactor,
                kspace_to_image,
            )

            def process_kspace(kspace, coil_idx, slice_idx, frame_idx):
                coil_kspace = kspace[:, :, coil_idx, slice_idx, frame_idx]
                mod_kspace = np.log(np.abs(coil_kspace) + 1e-9)

                img = kspace_to_image(coil_kspace)
                img = np.abs(img)
                img = img / img.max()

                return mod_kspace, img

            dataset = MICCAI2023(rootdir, view=CINEView.SAX, acc_factor=AccFactor.ACC8)
            subsampled_kspace, subsampled_mask, full_kspace = dataset[0]

            frame_idx = 5
            slice_idx = 0
            coil_idx = 9

            mod_full, img_full = process_kspace(full_kspace, coil_idx, slice_idx, frame_idx)
            mod_sub, img_sub = process_kspace(subsampled_kspace, coil_idx, slice_idx, frame_idx)

            # Plot the above magnitudes
            ...

    Displayed below is an example patient with the SAX view and acceleration of 8:

    .. figure:: ../assets/datasets/miccai2023_sax8.png
       :alt: Example patient from the MICCAI2023 dataset with both the full sampled and under sampled k-space and images
       :width: 100%
       :align: center

    Displayed below is an example patient with the LAX view and acceleration of 4:

    .. figure:: ../assets/datasets/miccai2023_lax4.png
       :alt: Example patient from the MICCAI2023 dataset with both the full sampled and under sampled k-space and images
       :width: 100%
       :align: center

    You can combine the coils using the root sum of squares
    to get a magnitude image (real valued) with all the
    coil contributions.

    Below are examples combining the coils for a given
    frame and slice, for LAX (top) and SAX (bottom). It uses
    the function :py:func:`torchcvnn.datasets.miccai2023.combine_coils_from_kspace`

    .. figure:: ../assets/datasets/miccai2023_combined_lax.png
       :alt: Example LAX, combining the coils
       :width: 50%
       :align: center

    .. figure:: ../assets/datasets/miccai2023_combined_sax.png
       :alt: Example SAX, combining the coils
       :width: 50%
       :align: center

    """

    def __init__(
        self,
        rootdir: Union[str, pathlib.Path],
        view: CINEView = CINEView.SAX,
        acc_factor: Union[AccFactor, int] = AccFactor.ACC4,
    ):
        self.rootdir = pathlib.Path(rootdir)

        # Enum members pass through unchanged, raw values are converted and
        # anything else raises a ValueError here, instead of leaving the
        # filenames unset (or failing on ``.value``) later on.
        view = CINEView(view)
        acc_factor = AccFactor(acc_factor)

        if view == CINEView.SAX:
            self.input_filename = "cine_sax.mat"
            self.mask_filename = "cine_sax_mask.mat"
        else:  # CINEView.LAX
            self.input_filename = "cine_lax.mat"
            self.mask_filename = "cine_lax_mask.mat"

        # Locations and MATLAB variable names of the data
        self.fullsampled_rootdir = self.rootdir / "MultiCoil" / "cine" / "TrainingSet"
        self.fullsampled_key = "kspace_full"
        self.subsampled_rootdir = (
            self.fullsampled_rootdir / f"AccFactor{acc_factor.value:02d}"
        )
        self.subsampled_key = f"kspace_sub{acc_factor.value:02d}"
        self.mask_key = f"mask{acc_factor.value:02d}"

        logging.info(f"Loading data from {self.subsampled_rootdir}")

        if not self.subsampled_rootdir.is_dir():
            raise FileNotFoundError(
                f"Subsampled data directory {self.subsampled_rootdir} does not exist"
            )

        # We list all the patients in the subsampled data directory
        # and check we have the data, mask and full sampled data.
        # Sorting makes the index -> patient mapping independent of the
        # order in which the filesystem returns the directory entries.
        self.patients = []
        for patient in sorted(self.subsampled_rootdir.iterdir()):
            if not patient.is_dir():
                continue

            if not (patient / self.input_filename).exists():
                logging.warning(f"Missing {self.input_filename} for patient {patient}")
                continue

            if not (patient / self.mask_filename).exists():
                logging.warning(f"Missing {self.mask_filename} for patient {patient}")
                continue

            fullsampled_patient = self.fullsampled_rootdir / patient.name
            if not (fullsampled_patient / self.input_filename).exists():
                logging.warning(
                    f"Missing {self.input_filename} for patient {fullsampled_patient}"
                )
                continue

            self.patients.append(patient)

        logging.debug(
            f"I found {len(self.patients)} patient(s) : {[p.name for p in self.patients]}"
        )

    def __len__(self):
        """Number of patients with complete (subsampled, mask, full) data."""
        return len(self.patients)

    @staticmethod
    def _load_kspace(filename: pathlib.Path, key: str) -> np.ndarray:
        """Load the complex k-space array stored under ``key`` in ``filename``."""
        # Presumably h5py exposes the MATLAB array with its axes reversed; this
        # permutation yields the layout documented in __getitem__,
        # e.g. (246, 512, 10, 10, 12).
        data = load_matlab_file(filename, key).transpose(3, 4, 2, 1, 0)
        # The complex values are stored as a compound with "real"/"imag" fields
        return data["real"] + 1j * data["imag"]

    def __getitem__(self, idx):
        """
        Returns the subsampled k-space data, the mask and the fully sampled k-space data

        Arguments:
            idx: index of the patient

        Returns:
            tuple (subsampled_kspace, mask, fullsampled_kspace) of np.ndarray.
            The k-space arrays are complex with 5 axes: the two in-plane
            k-space axes, then coils, slices and time frames,
            e.g. (246, 512, 10, 10, 12). The mask is 2-D and matches the
            first two axes, e.g. (246, 512).
        """
        patient = self.patients[idx]

        subsampled_path = patient / self.input_filename
        logging.info(f"Loading {subsampled_path}")
        subsampled_data = self._load_kspace(subsampled_path, self.subsampled_key)

        mask_path = patient / self.mask_filename
        logging.info(f"Loading {mask_path}")
        # The mask is used as read: transpose(0, 1) on a 2-D array is the identity
        subsampled_mask = load_matlab_file(mask_path, self.mask_key)

        fullsampled_path = self.fullsampled_rootdir / patient.name / self.input_filename
        logging.info(f"Loading {fullsampled_path}")
        fullsampled_data = self._load_kspace(fullsampled_path, self.fullsampled_key)

        return subsampled_data, subsampled_mask, fullsampled_data