Classification of MNIST digits from their Fourier representation

In this tutorial, we show how to create and optimize a complex-valued neural network. As an illustration, suppose you want to classify the MNIST digits from their Fourier space representation.

Loading the complex-valued MNIST dataset

Using torchvision.datasets.MNIST and torch.fft.fft(), we can easily transform the MNIST dataset into its Fourier space representation:

dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
)
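
As a quick sanity check, you can inspect one transformed sample. torch.fft.fft promotes the integer tensor produced by PILToTensor, so with the default float32 settings each sample should come out as a complex64 tensor with the same shape as the image (the exact dtype depends on your default dtype settings):

x, y = dataset[0]
print(x.dtype, x.shape)
# expected with float32 defaults: torch.complex64 torch.Size([1, 28, 28])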

Implementing the complex-valued convolutional neural network

To implement a complex-valued CNN, we stack several convolutional blocks and flatten their output:

cdtype = torch.complex64

conv_model = nn.Sequential(
    *conv_block(1, 16, cdtype),
    *conv_block(16, 16, cdtype),
    *conv_block(16, 32, cdtype),
    *conv_block(32, 32, cdtype),
    nn.Flatten(),
)

with conv_block() defined as:

def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    """
    Builds a basic building block of
    `Conv2d`-`Cardioid`-`Conv2d`-`Cardioid`-`AvgPool2d`

    Arguments:
        in_c: the number of input channels
        out_c: the number of output channels
        cdtype: the dtype of complex values (expected to be torch.complex64 or torch.complex32)
    """
    return [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    ]
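
Since the convolutional stack ends with nn.Flatten(), the classification head needs to know how many features it produces. A simple way to get that number, also used in the complete code below, is a dummy forward pass with an input of the MNIST image size (28x28); num_features is just a local name for this sketch:

with torch.no_grad():
    conv_model.eval()
    dummy_input = torch.zeros((1, 1, 28, 28), dtype=cdtype)
    num_features = conv_model(dummy_input).shape[-1]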

We then use a standard optimizer, training loop, cross-entropy loss function, and so on. One detail deserves attention: nn.CrossEntropyLoss expects real-valued logits, so the network ends with a c_nn.Mod() layer that takes the modulus of the complex outputs, as in the sketch below.
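
Here is a minimal sketch of the classification head and the optimization setup, mirroring the complete code given below; num_features comes from the dummy forward pass above, and the hidden size of 124 and the learning rate of 3e-4 are the values used in the full script:

lin_model = nn.Sequential(
    nn.Linear(num_features, 124, dtype=cdtype),
    c_nn.Cardioid(),
    nn.Linear(124, 10, dtype=cdtype),
    c_nn.Mod(),  # modulus of the complex logits, giving real values
)
model = nn.Sequential(conv_model, lin_model)

f_loss = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=3e-4)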

The full code is available for download as mnist.py and is given in full below. To run it, you also need the utils.py file, which provides some utility functions. Finally, a few dependencies besides torchcvnn are required:

python3 -m pip install torchvision tqdm

If you run that script, the expected output should be:

Logging to ./logs/CMNIST_0
>> Training
100%|██████| 844/844 [00:17<00:00, 48.61it/s]
>> Testing
[Step 0] Train : CE  0.20 Acc  0.94 | Valid : CE  0.08 Acc  0.97 | Test : CE 0.06 Acc  0.98[>> BETTER <<]

>> Training
100%|██████| 844/844 [00:16<00:00, 51.69it/s]
>> Testing
[Step 1] Train : CE  0.06 Acc  0.98 | Valid : CE  0.06 Acc  0.98 | Test : CE 0.05 Acc  0.98[>> BETTER <<]

>> Training
100%|██████| 844/844 [00:15<00:00, 53.47it/s]
>> Testing
[Step 2] Train : CE  0.04 Acc  0.99 | Valid : CE  0.04 Acc  0.99 | Test : CE 0.04 Acc  0.99[>> BETTER <<]

[...]

Complete code

mnist.py

# MIT License

# Copyright (c) 2023 Jérémy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
# Example using complex valued neural networks to classify MNIST from the Fourier Transform of the digits.



Requires dependencies :
    python3 -m pip install torchvision tqdm
"""

# Standard imports
import random
import sys
from typing import List

# External imports
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2_transforms

import torchcvnn.nn as c_nn

# Local imports
import utils


def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    """
    Builds a basic building block of
    `Conv2d`-`Cardioid`-`Conv2d`-`Cardioid`-`AvgPool2d`

    Arguments:
        in_c: the number of input channels
        out_c: the number of output channels
        cdtype: the dtype of complex values (expected to be torch.complex64 or torch.complex32)
    """
    return [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    ]


def train():
    """
    Train function

    Sample output :
        ```.bash
        (venv) me@host:~$ python mnist.py
        Logging to ./logs/CMNIST_0
        >> Training
        100%|██████| 844/844 [00:17<00:00, 48.61it/s]
        >> Testing
        [Step 0] Train : CE  0.20 Acc  0.94 | Valid : CE  0.08 Acc  0.97 | Test : CE 0.06 Acc  0.98[>> BETTER <<]

        >> Training
        100%|██████| 844/844 [00:16<00:00, 51.69it/s]
        >> Testing
        [Step 1] Train : CE  0.06 Acc  0.98 | Valid : CE  0.06 Acc  0.98 | Test : CE 0.05 Acc  0.98[>> BETTER <<]

        >> Training
        100%|██████| 844/844 [00:15<00:00, 53.47it/s]
        >> Testing
        [Step 2] Train : CE  0.04 Acc  0.99 | Valid : CE  0.04 Acc  0.99 | Test : CE 0.04 Acc  0.99[>> BETTER <<]

        [...]
        ```

    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    valid_ratio = 0.1
    batch_size = 64
    epochs = 10
    cdtype = torch.complex64

    # Dataloading
    train_valid_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )
    test_dataset = torchvision.datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
    )

    all_indices = list(range(len(train_valid_dataset)))
    random.shuffle(all_indices)
    split_idx = int(valid_ratio * len(train_valid_dataset))
    valid_indices, train_indices = all_indices[:split_idx], all_indices[split_idx:]

    # Train dataloader
    train_dataset = torch.utils.data.Subset(train_valid_dataset, train_indices)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    # Valid dataloader
    valid_dataset = torch.utils.data.Subset(train_valid_dataset, valid_indices)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=False
    )

    # Test dataloader
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )

    # Model
    conv_model = nn.Sequential(
        *conv_block(1, 16, cdtype),
        *conv_block(16, 16, cdtype),
        *conv_block(16, 32, cdtype),
        *conv_block(32, 32, cdtype),
        nn.Flatten(),
    )

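    # Infer the number of flattened features with a dummy forward pass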
    with torch.no_grad():
        conv_model.eval()
        dummy_input = torch.zeros((64, 1, 28, 28), dtype=cdtype, requires_grad=False)
        out_conv = conv_model(dummy_input).view(64, -1)
    lin_model = nn.Sequential(
        nn.Linear(out_conv.shape[-1], 124, dtype=cdtype),
        c_nn.Cardioid(),
        nn.Linear(124, 10, dtype=cdtype),
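        # Take the modulus of the complex logits to get real values for the cross-entropy loss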
        c_nn.Mod(),
    )
    model = nn.Sequential(conv_model, lin_model)
    model.to(device)

    # Loss, optimizer, callbacks
    f_loss = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)
    logpath = utils.generate_unique_logpath("./logs", "CMNIST")
    print(f"Logging to {logpath}")
    checkpoint = utils.ModelCheckpoint(model, logpath, 4, min_is_best=True)

    # Training loop
    for e in range(epochs):
        print(">> Training")
        train_loss, train_acc = utils.train_epoch(
            model, train_loader, f_loss, optim, device
        )

        print(">> Testing")
        valid_loss, valid_acc = utils.test_epoch(model, valid_loader, f_loss, device)
        test_loss, test_acc = utils.test_epoch(model, test_loader, f_loss, device)
        updated = checkpoint.update(valid_loss)
        better_str = "[>> BETTER <<]" if updated else ""

        print(
            f"[Step {e}] Train : CE {train_loss:5.2f} Acc {train_acc:5.2f} | Valid : CE {valid_loss:5.2f} Acc {valid_acc:5.2f} | Test : CE {test_loss:5.2f} Acc {test_acc:5.2f}"
            + better_str
        )


if __name__ == "__main__":
    train()

utils.py

# coding: utf-8
# MIT License

# Copyright (c) 2023 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
import os
from typing import Tuple

# External imports
import torch
import torch.nn as nn
import tqdm

# import torch.onnx


def train_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    f_loss: nn.Module,
    optim: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run one training epoch over all the minibatches of the dataloader

    Arguments:
        model: the model to train
        loader: an iterable dataloader
        f_loss (nn.Module): the loss
        optim: an optimizing algorithm
        device: the device on which to run the code

    Returns:
        The averaged training loss
        The averaged training accuracy
    """
    model.train()

    loss_avg = 0
    acc_avg = 0
    num_samples = 0
    for inputs, outputs in tqdm.tqdm(loader):
        inputs = inputs.to(device)
        outputs = outputs.to(device)

        # Forward propagate through the model
        pred_outputs = model(inputs)

        # Forward propagate through the loss
        loss = f_loss(pred_outputs, outputs)

        # Backward pass and update
        optim.zero_grad()
        loss.backward()
        optim.step()

        num_samples += inputs.shape[0]

        # Denormalize the loss that is supposed to be averaged over the
        # minibatch
        loss_avg += inputs.shape[0] * loss.item()
        pred_cls = pred_outputs.argmax(dim=-1)
        acc_avg += (pred_cls == outputs).sum().item()

    return loss_avg / num_samples, acc_avg / num_samples


def test_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    f_loss: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run one evaluation epoch over all the minibatches of the dataloader

    Arguments:
        model: the model to evaluate
        loader: an iterable dataloader
        f_loss: the loss
        device: the device on which to run the code

    Returns:
        The averaged test loss
        The averaged test accuracy

    """
    model.eval()

    loss_avg = 0
    acc_avg = 0
    num_samples = 0
    with torch.no_grad():
        for inputs, outputs in loader:
            inputs = inputs.to(device)
            outputs = outputs.to(device)

            # Forward propagate through the model
            pred_outputs = model(inputs)

            # Forward propagate through the loss
            loss = f_loss(pred_outputs, outputs)

            loss_avg += inputs.shape[0] * loss.item()
            pred_cls = pred_outputs.argmax(dim=-1)
            acc_avg += (pred_cls == outputs).sum().item()
            num_samples += inputs.shape[0]

    return loss_avg / num_samples, acc_avg / num_samples


class ModelCheckpoint(object):
    def __init__(
        self,
        model: torch.nn.Module,
        savepath: str,
        num_input_dims: int,
        min_is_best: bool = True,
    ) -> None:
        """
        Early stopping callback

        Arguments:
            model: the model to save
            savepath: the location where to save the model's parameters
            num_input_dims: the number of dimensions for the input tensor (required for onnx export)
            min_is_best: whether a lower score (True) or a higher score (False) is considered better
        """
        self.model = model
        self.savepath = savepath
        self.num_input_dims = num_input_dims
        self.best_score = None
        if min_is_best:
            self.is_better = self.lower_is_better
        else:
            self.is_better = self.higher_is_better

    def lower_is_better(self, score: float) -> bool:
        """
        Test if the provided score is lower than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res: whether the provided score is lower than the best score found so far
        """
        return self.best_score is None or score < self.best_score

    def higher_is_better(self, score: float) -> bool:
        """
        Test if the provided score is higher than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res: whether the provided score is higher than the best score found so far
        """
        return self.best_score is None or score > self.best_score

    def update(self, score: float) -> bool:
        """
        If the provided score is better than the best score registered so far,
        saves the model's parameters on disk as a pytorch tensor

        Arguments:
            score: the new score to consider

        Returns:
            res: whether or not the provided score is better than the best score
                 registered so far
        """
        if self.is_better(score):
            self.model.eval()

            torch.save(
                self.model.state_dict(), os.path.join(self.savepath, "best_model.pt")
            )

            # torch.onnx.export(
            #     self.model,
            #     dummy_input,
            #     os.path.join(self.savepath, "best_model.onnx"),
            #     verbose=False,
            #     input_names=["input"],
            #     output_names=["output"],
            #     dynamic_axes={
            #         "input": {0: "batch"},
            #         "output": {0: "batch"},
            #     },
            # )

            self.best_score = score
            return True
        return False


def generate_unique_logpath(logdir: str, raw_run_name: str) -> str:
    """
    Generate a unique directory name and create it if necessary

    Arguments:
        logdir: the prefix directory
        raw_run_name: the base name

    Returns:
        log_path: a non-existent path like logdir/raw_run_name_xxxx
                  where xxxx is an int
    """
    i = 0
    while True:
        run_name = raw_run_name + "_" + str(i)
        log_path = os.path.join(logdir, run_name)
        if not os.path.isdir(log_path):
            os.makedirs(log_path)
            return log_path
        i = i + 1