Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/vit.py: 100% (37 statements)
# MIT License
#
# Copyright (c) 2024 Jeremy Fix
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Callable

# External imports
import torch
import torch.nn as nn

# Local imports
from .normalization import LayerNorm
from .activation import modReLU, MultiheadAttention
from .dropout import Dropout


class ViTLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        """
        The ViT layer cascades a multi-head attention block with a feed-forward network.

        Args:
            num_heads: Number of heads in the multi-head attention block.
            hidden_dim: Hidden dimension of the transformer.
            mlp_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).

        .. math::

            x & = x + \\text{attn}(\\text{norm1}(x))\\\\
            x & = x + \\text{ffn}(\\text{norm2}(x))

        The FFN block is a two-layer MLP with a modReLU activation function.
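
        As an illustrative sketch (the sizes below are arbitrary and the complex
        input dtype matches the default ``dtype`` of the layer), a single layer
        maps a complex sequence to a sequence of the same shape:

        .. code-block:: python

            layer = ViTLayer(num_heads=4, hidden_dim=64, mlp_dim=128)
            x = torch.randn(2, 16, 64, dtype=torch.complex64)  # (B, seq_len, hidden_dim)
            y = layer(x)  # same shape: (2, 16, 64)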
        """
        super(ViTLayer, self).__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.norm1 = norm_layer(hidden_dim, **factory_kwargs)
        self.attn = MultiheadAttention(
            embed_dim=hidden_dim,
            dropout=attention_dropout,
            num_heads=num_heads,
            batch_first=True,
            **factory_kwargs
        )
        self.dropout = Dropout(dropout)
        self.norm2 = norm_layer(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim, **factory_kwargs),
            norm_layer(mlp_dim),
            modReLU(),
            Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim, **factory_kwargs),
            Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass through the layer using pre-normalization.

        Args:
            x: Input tensor of shape (B, seq_len, hidden_dim)
        """
        norm_x = self.norm1(x)
        x = x + self.dropout(self.attn(norm_x, norm_x, norm_x, need_weights=False)[0])
        x = x + self.ffn(self.norm2(x))

        return x


class ViT(nn.Module):

    def __init__(
        self,
        patch_embedder: nn.Module,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """
        Vision Transformer model. This implementation does not contain any head.

        For classification, you can, for example, compute a global average of the
        output embeddings:

        .. code-block:: python

            backbone = c_nn.ViT(
                embedder,
                num_layers,
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                norm_layer=norm_layer,
            )

            # A Linear decoding head to project onto the logits
            head = nn.Sequential(
                nn.Linear(hidden_dim, 10, dtype=torch.complex64), c_nn.Mod()
            )

            x = torch.randn(B, C, H, W)
            features = backbone(x)  # B, num_patches, hidden_dim

            # Global average pooling of the patches encoding
            mean_features = features.mean(dim=1)  # B, hidden_dim

            head(mean_features)
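
        Note that, with the default ``dtype=torch.complex64``, the backbone and the
        head above operate on complex-valued tensors; in this sketch, ``x`` would
        typically be drawn with ``dtype=torch.complex64`` as well (an illustrative
        assumption, not something enforced by this module).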

        Args:
            patch_embedder: PatchEmbedder instance.
            num_layers: Number of layers in the transformer.
            num_heads: Number of heads in the multi-head attention block.
            hidden_dim: Hidden dimension of the transformer.
            mlp_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
        """
        super(ViT, self).__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.patch_embedder = patch_embedder
        self.dropout = Dropout(dropout)
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(
                ViTLayer(
                    num_heads=num_heads,
                    hidden_dim=hidden_dim,
                    mlp_dim=mlp_dim,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    norm_layer=norm_layer,
                    **factory_kwargs
                )
            )
        self.layers = nn.Sequential(*self.layers)
        self.norm = norm_layer(hidden_dim, **factory_kwargs)

    def forward(self, x):
        # x : (B, C, H, W)
        embedding = self.patch_embedder(x)  # (B, embed_dim, num_patch_H, num_patch_W)
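        # For instance, with a hypothetical embedder using 16x16 patches on a
        # 224x224 input, num_patch_H = num_patch_W = 14, i.e. 196 patches below.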

        # Transpose to (B, "seq_len"=num_patches, embed_dim)
        embedding = embedding.flatten(2).transpose(1, 2).contiguous()

        out = self.layers(self.dropout(embedding))

        out = self.norm(out)

        # The output is already batch-first: (B, num_patches, hidden_dim)
        return out