Classification of MNIST digits from their Fourier representation¶
In this tutorial, we show how to create a complex valued neural network and how to optimize it. As an illustration, suppose you want to classify the MNIST digits from their Fourier space representation.
Loading the complex valued MNIST dataset¶
Using torchvision.datasets.MNIST and torch.fft.fft(), we can easily transform the MNIST dataset into its Fourier space representation:
# Every image is first converted to a tensor and then passed through
# torch.fft.fft (a 1D FFT, computed along the last dimension by default), so
# the dataset directly yields the Fourier representation of the digits.
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
)
Implementing the complex valued convolutional neural network¶
To implement a complex valued CNN, we stack several complex valued convolutional blocks followed by a flatten layer:
# Complex dtype given to the convolutions of every block
cdtype = torch.complex64
# Four convolutional blocks (1 -> 16 -> 16 -> 32 -> 32 channels), followed by a
# flatten layer so that the output can be fed to fully connected layers
conv_model = nn.Sequential(
    *conv_block(1, 16, cdtype),
    *conv_block(16, 16, cdtype),
    *conv_block(16, 32, cdtype),
    *conv_block(32, 32, cdtype),
    nn.Flatten(),
)
with conv_block() defined as:
def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    """
    Builds a basic building block of
    `Conv2d`-`BatchNorm2d`-`Cardioid`-`Conv2d`-`BatchNorm2d`-`Cardioid`-`AvgPool2d`

    Arguments:
        in_c : the number of input channels
        out_c : the number of output channels
        cdtype : the dtype given to the convolutions; expected to be a complex
                 dtype such as torch.complex64

    Returns:
        The layers of the block, in order, as a list of modules
    """
    return [
        # 3x3 convolution with stride 1 and padding 1 : the spatial size is kept
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
        c_nn.BatchNorm2d(out_c),
        c_nn.Cardioid(),
        # 2x2 average pooling with stride 2
        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
    ]
And then we use a standard optimizer, training loop, cross entropy loss function, etc…
The full code is available for download as mnist.py
and is given in full below. To run the code, you also need the utils.py
file, which provides some utility functions. Finally, some dependencies in addition to torchcvnn are needed:
python3 -m pip install torchvision tqdm
If you run that script, the expected output should be :
Logging to ./logs/CMNIST_0
>> Training
100%|██████| 844/844 [00:17<00:00, 48.61it/s]
>> Testing
[Step 0] Train : CE 0.20 Acc 0.94 | Valid : CE 0.08 Acc 0.97 | Test : CE 0.06 Acc 0.98[>> BETTER <<]
>> Training
100%|██████| 844/844 [00:16<00:00, 51.69it/s]
>> Testing
[Step 1] Train : CE 0.06 Acc 0.98 | Valid : CE 0.06 Acc 0.98 | Test : CE 0.05 Acc 0.98[>> BETTER <<]
>> Training
100%|██████| 844/844 [00:15<00:00, 53.47it/s]
>> Testing
[Step 2] Train : CE 0.04 Acc 0.99 | Valid : CE 0.04 Acc 0.99 | Test : CE 0.04 Acc 0.99[>> BETTER <<]
[...]
Complete code¶
mnist.py¶
# MIT License
# Copyright (c) 2023 Jérémy Fix
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
# Example using complex valued neural networks to classify MNIST from the Fourier Transform of the digits.
Requires dependencies :
python3 -m pip install torchvision tqdm
"""
# Standard imports
import random
import sys
from typing import List
# External imports
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as v2_transforms
import torchcvnn.nn as c_nn
# Local imports
import utils
def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
    """
    Builds a basic building block of
    `Conv2d`-`BatchNorm2d`-`Cardioid`-`Conv2d`-`BatchNorm2d`-`Cardioid`-`AvgPool2d`

    Arguments:
        in_c : the number of input channels
        out_c : the number of output channels
        cdtype : the dtype given to the convolutions; expected to be a complex
                 dtype such as torch.complex64

    Returns:
        The layers of the block, in order, as a list of modules
    """
    layers: List[nn.Module] = []
    # Two `Conv2d`-`BatchNorm2d`-`Cardioid` stages : the first one maps in_c to
    # out_c channels, the second one keeps out_c channels
    for stage_in_c in (in_c, out_c):
        layers.append(
            nn.Conv2d(
                stage_in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype
            )
        )
        layers.append(c_nn.BatchNorm2d(out_c))
        layers.append(c_nn.Cardioid())
    # The block ends with a 2x2 average pooling with stride 2
    layers.append(c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
    return layers
def train():
    """
    Train a complex valued CNN on the Fourier transform of the MNIST digits

    The MNIST training images are split into a training fold and a validation
    fold. After every epoch the model is evaluated on the validation and test
    sets, and its parameters are checkpointed whenever the validation loss
    improves.

    Sample output :

    ```.bash
    (venv) me@host:~$ python mnist.py
    Logging to ./logs/CMNIST_0
    >> Training
    100%|██████| 844/844 [00:17<00:00, 48.61it/s]
    >> Testing
    [Step 0] Train : CE 0.20 Acc 0.94 | Valid : CE 0.08 Acc 0.97 | Test : CE 0.06 Acc 0.98[>> BETTER <<]
    >> Training
    100%|██████| 844/844 [00:16<00:00, 51.69it/s]
    >> Testing
    [Step 1] Train : CE 0.06 Acc 0.98 | Valid : CE 0.06 Acc 0.98 | Test : CE 0.05 Acc 0.98[>> BETTER <<]
    [...]
    ```
    """
    # Hyper-parameters
    cdtype = torch.complex64
    epochs = 10
    batch_size = 64
    valid_ratio = 0.1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datasets : every image is converted to a tensor and then to its FFT
    def load_mnist(is_train: bool):
        return torchvision.datasets.MNIST(
            root="./data",
            train=is_train,
            download=True,
            transform=v2_transforms.Compose(
                [v2_transforms.PILToTensor(), torch.fft.fft]
            ),
        )

    train_valid_dataset = load_mnist(True)
    test_dataset = load_mnist(False)

    # Random split of the training images : the first `valid_ratio` fraction
    # of the shuffled indices is used for validation, the rest for training
    shuffled_indices = list(range(len(train_valid_dataset)))
    random.shuffle(shuffled_indices)
    num_valid = int(valid_ratio * len(train_valid_dataset))

    def make_loader(dataset, shuffle: bool):
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle
        )

    train_loader = make_loader(
        torch.utils.data.Subset(train_valid_dataset, shuffled_indices[num_valid:]),
        True,
    )
    valid_loader = make_loader(
        torch.utils.data.Subset(train_valid_dataset, shuffled_indices[:num_valid]),
        False,
    )
    test_loader = make_loader(test_dataset, False)

    # Convolutional part : one block per pair of consecutive channel counts
    channels = [1, 16, 16, 32, 32]
    conv_layers = []
    for block_in_c, block_out_c in zip(channels[:-1], channels[1:]):
        conv_layers.extend(conv_block(block_in_c, block_out_c, cdtype))
    conv_model = nn.Sequential(*conv_layers, nn.Flatten())

    # A dummy forward pass gives the number of features produced by the
    # convolutional part, which is needed to size the first linear layer
    with torch.no_grad():
        conv_model.eval()
        probe = torch.zeros((64, 1, 28, 28), dtype=cdtype, requires_grad=False)
        num_features = conv_model(probe).view(64, -1).shape[-1]

    # Classification head : `Mod` is applied last, on the 10 outputs
    lin_model = nn.Sequential(
        nn.Linear(num_features, 124, dtype=cdtype),
        c_nn.Cardioid(),
        nn.Linear(124, 10, dtype=cdtype),
        c_nn.Mod(),
    )
    model = nn.Sequential(conv_model, lin_model)
    model.to(device)

    # Loss, optimizer, callbacks
    f_loss = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=3e-4)
    logpath = utils.generate_unique_logpath("./logs", "CMNIST")
    print(f"Logging to {logpath}")
    checkpoint = utils.ModelCheckpoint(model, logpath, 4, min_is_best=True)

    # Training loop
    for e in range(epochs):
        print(">> Training")
        train_loss, train_acc = utils.train_epoch(
            model, train_loader, f_loss, optim, device
        )

        print(">> Testing")
        valid_loss, valid_acc = utils.test_epoch(model, valid_loader, f_loss, device)
        test_loss, test_acc = utils.test_epoch(model, test_loader, f_loss, device)

        # The checkpoint is driven by the validation loss only
        better_str = "[>> BETTER <<]" if checkpoint.update(valid_loss) else ""
        print(
            f"[Step {e}] Train : CE {train_loss:5.2f} Acc {train_acc:5.2f} | Valid : CE {valid_loss:5.2f} Acc {valid_acc:5.2f} | Test : CE {test_loss:5.2f} Acc {test_acc:5.2f}"
            + better_str
        )
# Run the training only when the file is executed as a script
if __name__ == "__main__":
    train()
utils.py¶
# coding: utf-8
# MIT License
# Copyright (c) 2023 Jeremy Fix
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
import os
from typing import Tuple
# External imports
import torch
import torch.nn as nn
import tqdm
# import torch.onnx
def train_epoch(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    f_loss: nn.Module,
    optim: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Train the model for one pass over all the minibatches of the dataloader

    Arguments:
        model: the model to train
        loader: an iterable dataloader yielding (inputs, targets) pairs
        f_loss (nn.Module): the loss, assumed to be averaged over the minibatch
        optim : an optimizing algorithm
        device: the device on which to run the code

    Returns:
        The averaged training loss
        The averaged training accuracy
    """
    model.train()

    total_loss = 0
    total_correct = 0
    seen_samples = 0
    for batch_inputs, batch_targets in tqdm.tqdm(loader):
        batch_inputs, batch_targets = (
            batch_inputs.to(device),
            batch_targets.to(device),
        )
        minibatch_size = batch_inputs.shape[0]

        # Forward pass through the model and the loss
        predictions = model(batch_inputs)
        loss = f_loss(predictions, batch_targets)

        # Backward pass and parameter update
        optim.zero_grad()
        loss.backward()
        optim.step()

        seen_samples += minibatch_size
        # The loss is weighted by the minibatch size so that the final
        # division yields an average per sample
        total_loss += minibatch_size * loss.item()
        # The predicted class is the index of the largest output
        total_correct += (predictions.argmax(dim=-1) == batch_targets).sum().item()

    return total_loss / seen_samples, total_correct / seen_samples
def test_epoch(
model: nn.Module,
loader: torch.utils.data.DataLoader,
f_loss: nn.Module,
device: torch.device,
) -> Tuple[float, float]:
"""
Run the test loop for n_test_batches minibatches of the dataloader
Arguments:
model: the model to evaluate
loader: an iterable dataloader
f_loss: the loss
device: the device on which to run the code
Returns:
The averaged test loss
The averaged test accuracy
"""
model.eval()
loss_avg = 0
acc_avg = 0
num_samples = 0
with torch.no_grad():
for inputs, outputs in loader:
inputs = inputs.to(device)
outputs = outputs.to(device)
# Forward propagate through the model
pred_outputs = model(inputs)
# Forward propagate through the loss
loss = f_loss(pred_outputs, outputs)
loss_avg += inputs.shape[0] * loss.item()
pred_cls = pred_outputs.argmax(dim=-1)
acc_avg += (pred_cls == outputs).sum().item()
num_samples += inputs.shape[0]
return loss_avg / num_samples, acc_avg / num_samples
class ModelCheckpoint(object):
    """
    Checkpoint callback : saves the parameters of the model every time the
    monitored score improves
    """

    def __init__(
        self,
        model: torch.nn.Module,
        savepath: str,
        num_input_dims: int,
        min_is_best: bool = True,
    ) -> None:
        """
        Arguments:
            model: the model to save
            savepath: the directory in which to save the model's parameters;
                      it is created on the first save if it does not exist
            num_input_dims: the number of dimensions for the input tensor. It
                            is only stored on the instance and is not used for
                            saving (originally intended for an onnx export)
            min_is_best: whether the lowest (True) or the highest (False)
                         score is considered the best
        """
        self.model = model
        self.savepath = savepath
        self.num_input_dims = num_input_dims
        # None until a first score has been registered
        self.best_score = None
        if min_is_best:
            self.is_better = self.lower_is_better
        else:
            self.is_better = self.higher_is_better

    def lower_is_better(self, score: float) -> bool:
        """
        Test if the provided score is lower than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res : is the provided score lower than the best score so far ?
                  Always True when no score has been registered yet
        """
        return self.best_score is None or score < self.best_score

    def higher_is_better(self, score: float) -> bool:
        """
        Test if the provided score is higher than the best score found so far

        Arguments:
            score: the score to test

        Returns:
            res : is the provided score higher than the best score so far ?
                  Always True when no score has been registered yet
        """
        return self.best_score is None or score > self.best_score

    def update(self, score: float) -> bool:
        """
        If the provided score is better than the best score registered so far,
        saves the model's parameters on disk as `best_model.pt` in `savepath`

        Arguments:
            score: the new score to consider

        Returns:
            res: whether or not the provided score is better than the best score
                 registered so far
        """
        if not self.is_better(score):
            return False

        # NOTE(review): the switch to evaluation mode is kept for backward
        # compatibility; callers training afterwards must call model.train()
        self.model.eval()
        # Make sure the target directory exists so that the save cannot fail
        # because of a missing directory
        os.makedirs(self.savepath, exist_ok=True)
        torch.save(
            self.model.state_dict(), os.path.join(self.savepath, "best_model.pt")
        )
        # The best score is only updated once the parameters are on disk
        self.best_score = score
        return True
def generate_unique_logpath(logdir: str, raw_run_name: str) -> str:
    """
    Generate a unique directory name and create the directory

    The candidates logdir/raw_run_name_0, logdir/raw_run_name_1, ... are tried
    in order and the first one that can be created is returned. A candidate
    that already exists, either as a directory or as a file, is skipped.

    Arguments:
        logdir: the prefix directory, created if necessary
        raw_run_name: the base name

    Returns:
        log_path: the path of the newly created directory, like
                  logdir/raw_run_name_xxxx where xxxx is an int
    """
    i = 0
    while True:
        log_path = os.path.join(logdir, raw_run_name + "_" + str(i))
        try:
            # Creating the directory is itself the existence test : this
            # avoids the race between a separate check and the creation, and
            # also skips names already taken by a regular file
            os.makedirs(log_path)
        except FileExistsError:
            i = i + 1
        else:
            return log_path