Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/datasets/miccai2023.py: 0%
82 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
1# coding: utf-8
3# MIT License
5# Copyright (c) 2024 Jeremy Fix
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
14# The above copyright notice and this permission notice shall be included in
15# all copies or substantial portions of the Software.
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25# Standard imports
26from enum import Enum
27import pathlib
28import logging
29from typing import Union
31# External imports
32import torch
33from torch.utils.data import Dataset
34import h5py # Required because the data are matlab v7.3 files
35import numpy as np
class CINEView(Enum):
    """
    Cardiac cine view to load from the MICCAI2023 dataset.

    ``MICCAI2023`` maps ``SAX`` to the files ``cine_sax.mat`` / ``cine_sax_mask.mat``
    and ``LAX`` to the files ``cine_lax.mat`` / ``cine_lax_mask.mat``.
    """

    SAX = 1  # short axis view
    LAX = 2  # long axis view
class AccFactor(Enum):
    """
    Acceleration (k-space undersampling) factor of the MICCAI2023 data.

    ``MICCAI2023`` formats the value on two digits to build the subsampled
    directory name (e.g. ``AccFactor04``) and the HDF5 keys of the subsampled
    k-space (e.g. ``kspace_sub04``) and of the mask (e.g. ``mask04``).
    """

    ACC4 = 4
    ACC8 = 8
    ACC10 = 10
def load_matlab_file(filename: Union[str, pathlib.Path], key: str) -> np.ndarray:
    """
    Load one dataset from a MATLAB v7.3 file (which is an HDF5 file).

    Arguments:
        filename : str or pathlib.Path
            Path to the ``.mat`` file
        key : str
            Name of the dataset to read in the file

    Returns:
        np.ndarray
            The whole content of the dataset, read into memory. Axes are in the
            order reported by h5py (presumably the reverse of the MATLAB
            dimension order; callers transpose as needed).

    Raises:
        KeyError: if ``key`` is not a dataset/group of the file
    """
    with h5py.File(filename, "r") as f:
        logging.debug("Got the keys %s from %s", f.keys(), filename)
        # Fail with an explicit message listing what is available rather than
        # the low level h5py lookup error.
        if key not in f:
            raise KeyError(
                f"Key {key!r} not found in {filename}; available keys: {list(f.keys())}"
            )
        data = f[key][()]
    return data
def kspace_to_image(
    kspace: Union[torch.Tensor, np.ndarray]
) -> Union[torch.Tensor, np.ndarray]:
    """
    Convert k-space data to image data. The returned image is of the same
    type as the provided k-space (np.ndarray or torch.Tensor).

    The (centered) inverse 2D FFT is applied on the last two axes; any leading
    axes (e.g. coils) are processed independently.

    Arguments:
        kspace : torch.Tensor or np.ndarray
            k-space data, with the k-space center in the middle of the last two axes

    Returns:
        torch.Tensor or np.ndarray
            complex image data, of the same shape as ``kspace``
    """
    if isinstance(kspace, torch.Tensor):
        # Only the transformed axes are shifted: on leading axes the ifftshift
        # and fftshift cancel each other and would only cost extra copies.
        img = torch.fft.fftshift(
            torch.fft.ifft2(torch.fft.ifftshift(kspace, dim=(-2, -1))), dim=(-2, -1)
        )
    else:
        img = np.fft.fftshift(
            np.fft.ifft2(np.fft.ifftshift(kspace, axes=(-2, -1))), axes=(-2, -1)
        )
    return img
def image_to_kspace(
    img: Union[torch.Tensor, np.ndarray]
) -> Union[torch.Tensor, np.ndarray]:
    """
    Convert image data to k-space data. The returned k-space is of the same
    type as the provided image (np.ndarray or torch.Tensor).

    The (centered) 2D FFT is applied on the last two axes; any leading axes
    (e.g. coils) are processed independently.

    Arguments:
        img : torch.Tensor or np.ndarray
            Image data, with the image center in the middle of the last two axes

    Returns:
        torch.Tensor or np.ndarray
            complex k-space data, of the same shape as ``img``, with the k-space
            center in the middle of the last two axes
    """
    if isinstance(img, torch.Tensor):
        # Only the transformed axes are shifted: on leading axes the ifftshift
        # and fftshift cancel each other and would only cost extra copies.
        kspace = torch.fft.fftshift(
            torch.fft.fft2(torch.fft.ifftshift(img, dim=(-2, -1))), dim=(-2, -1)
        )
    else:
        kspace = np.fft.fftshift(
            np.fft.fft2(np.fft.ifftshift(img, axes=(-2, -1))), axes=(-2, -1)
        )
    return kspace
def combine_coils_from_kspace(kspace: np.ndarray) -> np.ndarray:
    """
    Build a single magnitude image from multi-coil k-space by root sum of squares.

    Each coil is brought to the image domain with a centered inverse 2D FFT, and
    the per-pixel magnitudes are combined as sqrt(sum_coils |image|^2).

    Arguments:
        kspace : np.ndarray
            k-space data of shape (sc, ky, kx)

    Returns:
        np.ndarray
            Real valued, non-negative image of shape (ky, kx)

    Raises:
        ValueError: if ``kspace`` does not have exactly 3 dimensions
    """
    n_dims = kspace.ndim
    if n_dims == 3:
        # ifftshift also rolls the coil axis; this is harmless because the
        # coils are summed just below, and the sum does not depend on coil order.
        coil_images = np.fft.ifft2(np.fft.ifftshift(kspace))
        squared_magnitudes = np.abs(coil_images) ** 2
        rss = np.sqrt(squared_magnitudes.sum(axis=0))
        # Re-center the (ky, kx) image.
        return np.fft.fftshift(rss)
    raise ValueError(
        f"kspace should have 3 dimensions, got {n_dims}. Expected dimensions (sc, ky, kx)"
    )
class MICCAI2023(Dataset):
    """
    Loads the MICCAI2023 challenge data for the reconstruction task Task 1

    The data are described on https://cmrxrecon.github.io/Task1-Cine-reconstruction.html

    You need to download the data beforehand in order to use this class.

    For loading the data, you may want to alternatively consider the fastmri library, see https://github.com/facebookresearch/fastMRI/

    The structure of the dataset is as follows, where ``rootdir`` is the
    ``ChallengeData`` directory of the downloaded data (i.e. the directory
    containing ``MultiCoil``)::

        rootdir/MultiCoil/cine/TrainingSet/P{id}/
            - cine_sax.mat
            - cine_lax.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor04/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor08/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat
        rootdir/MultiCoil/cine/TrainingSet/AccFactor10/P{id}/
            - cine_sax.mat
            - cine_sax_mask.mat
            - cine_lax.mat
            - cine_lax_mask.mat

    The cine_sax or cine_lax files are :math:`(k_x, k_y, s_c, s_z, t)` where :

    - :math:`k_x`: matrix size in x-axis (k-space)
    - :math:`k_y`: matrix size in y-axis (k-space)
    - :math:`s_c`: coil array number (compressed to 10)
    - :math:`s_z`: slice number for short axis view, or slice group for long axis (i.e., 3ch, 2ch and 4ch views)
    - :math:`t`: time frame.

    For reference, :math:`s_x` and :math:`s_y` denote the matrix size in the
    x-axis and y-axis of the image (:math:`s_y` is used in single-coil data).

    Note that the k-space dimensions (along the x/y axes) vary from one patient to another.

    This is a reconstruction dataset. The goal is to reconstruct the fully sampled k-space
    from the subsampled k-space. The acceleration factor specifies the subsampling rate.

    The Single-Coil data are not yet supported by this implementation.

    Patients are indexed in the sorted order of their directory names, so that
    a given index refers to the same patient on every machine. Patients with a
    missing subsampled file, mask or fully sampled file are skipped with a warning.

    Note:
        An example usage :

        .. code-block:: python

            import numpy as np
            import torchcvnn
            from torchcvnn.datasets.miccai2023 import (
                MICCAI2023,
                CINEView,
                AccFactor,
                kspace_to_image,
            )

            def process_kspace(kspace, coil_idx, slice_idx, frame_idx):
                coil_kspace = kspace[:, :, coil_idx, slice_idx, frame_idx]
                mod_kspace = np.log(np.abs(coil_kspace) + 1e-9)

                img = kspace_to_image(coil_kspace)
                img = np.abs(img)
                img = img / img.max()

                return mod_kspace, img

            dataset = MICCAI2023(rootdir, view=CINEView.SAX, acc_factor=AccFactor.ACC8)
            subsampled_kspace, subsampled_mask, full_kspace = dataset[0]

            frame_idx = 5
            slice_idx = 0
            coil_idx = 9

            mod_full, img_full = process_kspace(full_kspace, coil_idx, slice_idx, frame_idx)
            mod_sub, img_sub = process_kspace(subsampled_kspace, coil_idx, slice_idx, frame_idx)

            # Plot the above magnitudes
            ...


        Displayed below is an example patient with the SAX view and acceleration of 8:

        .. figure:: ../assets/datasets/miccai2023_sax8.png
           :alt: Example patient from the MICCAI2023 dataset with both the full sampled and under sampled k-space and images
           :width: 100%
           :align: center

        Displayed below is an example patient with the LAX view and acceleration of 4:

        .. figure:: ../assets/datasets/miccai2023_lax4.png
           :alt: Example patient from the MICCAI2023 dataset with both the full sampled and under sampled k-space and images
           :width: 100%
           :align: center

        You can combine the coils using the root sum of squares
        to get a magnitude image (real valued) with all the
        coil contributions.


        Below are examples combining the coils for a given
        frame and slice, for LAX (top) and SAX (bottom). It uses
        the function :py:func:`torchcvnn.datasets.miccai2023.combine_coils_from_kspace`

        .. figure:: ../assets/datasets/miccai2023_combined_lax.png
           :alt: Example LAX, combining the coils
           :width: 50%
           :align: center

        .. figure:: ../assets/datasets/miccai2023_combined_sax.png
           :alt: Example SAX, combining the coils
           :width: 50%
           :align: center

    """

    def __init__(
        self,
        rootdir: Union[str, pathlib.Path],
        view: CINEView = CINEView.SAX,
        acc_factor: AccFactor = AccFactor.ACC4,
    ):
        """
        Arguments:
            rootdir: the ``ChallengeData`` directory (containing ``MultiCoil``)
            view: which cine view to load (short or long axis)
            acc_factor: which acceleration factor of subsampled data to load

        Raises:
            ValueError: if ``view`` is not a supported CINEView
        """
        self.rootdir = pathlib.Path(rootdir)

        if view == CINEView.SAX:
            self.input_filename = "cine_sax.mat"
            self.mask_filename = "cine_sax_mask.mat"
        elif view == CINEView.LAX:
            self.input_filename = "cine_lax.mat"
            self.mask_filename = "cine_lax_mask.mat"
        else:
            # Fail early instead of leaving the filenames undefined.
            raise ValueError(f"Unsupported view {view!r}, expected a CINEView member")

        # List all the available data
        self.fullsampled_rootdir = self.rootdir / "MultiCoil" / "cine" / "TrainingSet"
        self.fullsampled_key = "kspace_full"
        self.subsampled_rootdir = (
            self.fullsampled_rootdir / f"AccFactor{acc_factor.value:02d}"
        )
        self.subsampled_key = f"kspace_sub{acc_factor.value:02d}"
        self.mask_key = f"mask{acc_factor.value:02d}"

        logging.info("Loading data from %s", self.subsampled_rootdir)

        # We list all the patients in the subsampled data directory
        # and check we have the data, mask and full sampled data.
        # The listing is sorted because iterdir() order is filesystem dependent,
        # which would make dataset indices non reproducible.
        self.patients = []
        for patient in sorted(self.subsampled_rootdir.iterdir()):
            if not patient.is_dir():
                continue

            if not (patient / self.input_filename).exists():
                logging.warning(
                    "Missing %s for patient %s", self.input_filename, patient
                )
                continue

            if not (patient / self.mask_filename).exists():
                logging.warning(
                    "Missing %s for patient %s", self.mask_filename, patient
                )
                continue

            fullsampled_patient = self.fullsampled_rootdir / patient.name
            if not (fullsampled_patient / self.input_filename).exists():
                logging.warning(
                    "Missing %s for patient %s",
                    self.input_filename,
                    fullsampled_patient,
                )
                continue

            self.patients.append(patient)

        logging.debug(
            "I found %d patient(s) : %s",
            len(self.patients),
            [p.name for p in self.patients],
        )

    def __len__(self):
        """Number of patients having subsampled data, mask and fully sampled data."""
        return len(self.patients)

    @staticmethod
    def _load_kspace(filename, key):
        """Load a complex k-space stored as a compound (real, imag) dataset."""
        logging.info("Loading %s", filename)
        # Reorder the axes read by h5py into (kx, ky, sc, sz, t)
        data = load_matlab_file(filename, key).transpose(3, 4, 2, 1, 0)
        return data["real"] + 1j * data["imag"]

    def __getitem__(self, idx):
        """
        Returns the subsampled k-space data, the mask and the fully sampled k-space data

        Arguments:
            idx: index of the patient

        Returns:
            tuple (subsampled_kspace, mask, fullsampled_kspace):
                - subsampled_kspace: complex array (kx, ky, sc, sz, t),
                  e.g. (246, 512, 10, 10, 12)
                - mask: array (kx, ky), e.g. (246, 512)
                - fullsampled_kspace: complex array (kx, ky, sc, sz, t)
        """
        patient = self.patients[idx]

        subsampled_data = self._load_kspace(
            patient / self.input_filename, self.subsampled_key
        )

        # The mask is returned with the axis order read from the file,
        # which matches the two leading (kx, ky) axes of the k-space.
        logging.info("Loading %s", patient / self.mask_filename)
        subsampled_mask = load_matlab_file(patient / self.mask_filename, self.mask_key)

        fullsampled_data = self._load_kspace(
            self.fullsampled_rootdir / patient.name / self.input_filename,
            self.fullsampled_key,
        )

        return subsampled_data, subsampled_mask, fullsampled_data