Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/vit.py: 100%
35 statements
« prev ^ index » next — coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Callable
26# External imports
27import torch
28import torch.nn as nn
30# Local imports
31from .normalization import LayerNorm
32from .activation import MultiheadAttention, CGELU
33from .dropout import Dropout
class ViTLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        activation_fn: Callable[[], nn.Module] = CGELU,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        """
        The ViT layer cascades a multi-head attention block with a feed-forward network.

        Args:
            num_heads: Number of heads in the multi-head attention block.
            embed_dim: Hidden dimension of the transformer.
            hidden_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
            activation_fn: Activation used in the feed-forward network
                (default :py:class:`CGELU`).
            device: Device on which the parameters are allocated (default: None).
            dtype: Dtype of the parameters (default: ``torch.complex64``).

        .. math::

            x & = x + \\text{attn}(\\text{norm1}(x))\\\\
            x & = x + \\text{ffn}(\\text{norm2}(x))

        The FFN block is a two-layer MLP whose activation is ``activation_fn``
        (:py:class:`CGELU` by default).
        """
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        # NOTE(review): factory_kwargs is not forwarded to the norm layers below,
        # although ViT forwards it to its final norm layer — confirm whether this
        # asymmetry is intentional.
        self.norm1 = norm_layer(embed_dim)
        self.attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            norm_layer=norm_layer,
            batch_first=True,
            **factory_kwargs,
        )
        self.dropout1 = Dropout(dropout)

        # Pre-norm FFN: norm -> linear -> activation -> dropout, applied twice.
        self.ffn = nn.Sequential(
            norm_layer(embed_dim),
            nn.Linear(embed_dim, hidden_dim, **factory_kwargs),
            activation_fn(),
            Dropout(dropout),
            norm_layer(hidden_dim),
            nn.Linear(hidden_dim, embed_dim, **factory_kwargs),
            activation_fn(),
            Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass through the layer using pre-normalization.

        Args:
            x: Input tensor of shape (B, seq_len, embed_dim)

        Returns:
            Tensor of the same shape as the input.
        """
        norm_x = self.norm1(x)
        x = x + self.dropout1(self.attn(norm_x, norm_x, norm_x, need_weights=False))
        # The FFN normalizes its own input (first module of the Sequential),
        # so x is passed in un-normalized here.
        x = x + self.ffn(x)

        return x
class ViT(nn.Module):

    def __init__(
        self,
        patch_embedder: nn.Module,
        num_layers: int,
        num_heads: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """
        Vision Transformer model. This implementation does not contain any head.

        For classification, you can for example compute a global average of the output embeddings :

        .. code-block:: python

            backbone = c_nn.ViT(
                embedder,
                num_layers,
                num_heads,
                embed_dim,
                hidden_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                norm_layer=norm_layer,
            )

            # A Linear decoding head to project on the logits
            head = nn.Sequential(
                nn.Linear(embed_dim, 10, dtype=torch.complex64), c_nn.Mod()
            )

            x = torch.randn(B, C, H, W)
            features = backbone(x)  # B, num_patches, embed_dim

            # Global average pooling of the patches encoding
            mean_features = features.mean(dim=1)  # B, embed_dim

            head(mean_features)

        Args:
            patch_embedder: PatchEmbedder instance.
            num_layers: Number of layers in the transformer.
            num_heads: Number of heads in the multi-head attention block.
            embed_dim: Hidden dimension of the transformer.
            hidden_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
            device: Device on which the parameters are allocated (default: None).
            dtype: Dtype of the parameters (default: ``torch.complex64``).
        """
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.patch_embedder = patch_embedder
        self.dropout = Dropout(dropout)
        # Build the transformer stack directly as a Sequential instead of
        # appending to a throwaway ModuleList and rebinding.
        self.layers = nn.Sequential(
            *(
                ViTLayer(
                    num_heads=num_heads,
                    embed_dim=embed_dim,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    norm_layer=norm_layer,
                    **factory_kwargs,
                )
                for _ in range(num_layers)
            )
        )
        self.norm = norm_layer(embed_dim, **factory_kwargs)

    def forward(self, x):
        """
        Embeds the input through the patch embedder, runs the transformer
        layers and applies a final normalization.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, src_len, embed_dim) where src_len is the
            number of patches produced by the patch embedder.
        """
        embedding = self.patch_embedder(x)  # (B, src_len, embed_dim)

        out = self.layers(self.dropout(embedding))

        out = self.norm(out)

        return out