# MIT License

# Copyright (c) 2025 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
from typing import Optional

# External imports
import torch
import torch.nn as nn

# Local imports
import torchcvnn.nn as c_nn


def vit_t(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT tiny model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the patch embeddings as
    well as adding the positional encoding, if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the hidden dimension expected by the
    transformer (:math:`192` for the tiny model). A minimal embedder satisfying
    this contract is sketched below.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 3
    hidden_dim = 192
    mlp_dim = 4 * 192
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
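

# Illustrative only, not part of the torchcvnn API: a minimal, hypothetical
# patch embedder satisfying the contract documented above, assuming a plain
# reshape followed by a complex-valued Linear projection and no positional
# encoding. For real use, prefer the library's own PatchEmbedder referenced
# in the docstrings.
class _LinearPatchEmbedder(nn.Module):
    """Projects non-overlapping patches to hidden_dim with a complex Linear layer."""

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        hidden_dim: int,
        dtype: torch.dtype = torch.complex64,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_dim, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        p = self.patch_size
        nh, nw = H // p, W // p
        # (B, C, H, W) -> (B, nh * nw, C * p * p): one flattened row per patch
        patches = (
            x.reshape(B, C, nh, p, nw, p)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(B, nh * nw, C * p * p)
        )
        tokens = self.proj(patches)  # (B, nh * nw, hidden_dim)
        # Back to the (B, hidden_dim, N_h, N_w) layout expected by the builders
        return tokens.transpose(1, 2).reshape(B, -1, nh, nw)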
def vit_s(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT small model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the patch embeddings as
    well as adding the positional encoding, if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the hidden dimension expected by the
    transformer (:math:`384` for the small model).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 6
    hidden_dim = 384
    mlp_dim = 4 * 384
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
def vit_b(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT base model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the patch embeddings as
    well as adding the positional encoding, if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the hidden dimension expected by the
    transformer (:math:`768` for the base model).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 12
    hidden_dim = 768
    mlp_dim = 3072
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
def vit_l(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT large model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the patch embeddings as
    well as adding the positional encoding, if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the hidden dimension expected by the
    transformer (:math:`1024` for the large model).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 24
    num_heads = 16
    hidden_dim = 1024
    mlp_dim = 4096
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
def vit_h(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT huge model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the patch embeddings as
    well as adding the positional encoding, if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the hidden dimension expected by the
    transformer (:math:`1280` for the huge model).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 32
    num_heads = 16
    hidden_dim = 1280
    mlp_dim = 5120
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
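

# Hedged usage sketch: a quick smoke test of the builders above using the
# hypothetical _LinearPatchEmbedder. The input shape and patch size are
# arbitrary choices; only hidden_dim=192 is dictated by vit_t. The output
# shape depends on c_nn.ViT, which is not re-documented here.
if __name__ == "__main__":
    embedder = _LinearPatchEmbedder(in_channels=1, patch_size=4, hidden_dim=192)
    model = vit_t(embedder)
    x = torch.randn(2, 1, 32, 32, dtype=torch.complex64)  # (B, C, H, W)
    out = model(x)
    print(out.shape, out.dtype)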