Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/models/vision_transformer.py: 32%

53 statements  

coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

# MIT License

# Copyright (c) 2025 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Optional

# External imports
import torch
import torch.nn as nn

# Local imports
import torchcvnn.nn as c_nn
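# Summary of the variants built below (values taken from the builder
# functions in this file); all variants use dropout 0.0 and c_nn.RMSNorm:
#   vit_t (tiny) : 12 layers,  3 heads, hidden_dim  192, mlp_dim  768
#   vit_s (small): 12 layers,  6 heads, hidden_dim  384, mlp_dim 1536
#   vit_b (base) : 12 layers, 12 heads, hidden_dim  768, mlp_dim 3072
#   vit_l (large): 24 layers, 16 heads, hidden_dim 1024, mlp_dim 4096
#   vit_h (huge) : 32 layers, 16 heads, hidden_dim 1280, mlp_dim 5120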


def vit_t(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT tiny model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the embedding of the patch
    as well as adding the positional encoding if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the expected hidden dimension of the
    transformer.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 3
    hidden_dim = 192
    mlp_dim = 4 * 192
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )


def vit_s(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT small model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the embedding of the patch
    as well as adding the positional encoding if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the expected hidden dimension of the
    transformer.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 6
    hidden_dim = 384
    mlp_dim = 4 * 384
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )


def vit_b(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT base model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the embedding of the patch
    as well as adding the positional encoding if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the expected hidden dimension of the
    transformer.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 12
    num_heads = 12
    hidden_dim = 768
    mlp_dim = 3072
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )


def vit_l(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT large model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the embedding of the patch
    as well as adding the positional encoding if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the expected hidden dimension of the
    transformer.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 24
    num_heads = 16
    hidden_dim = 1024
    mlp_dim = 4096
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )


def vit_h(
    patch_embedder: nn.Module,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.complex64,
) -> nn.Module:
    r"""
    Builds a ViT huge model.

    Args:
        patch_embedder: PatchEmbedder instance.
        device: Device to use.
        dtype: Data type to use.

    The patch_embedder is responsible for computing the embedding of the patch
    as well as adding the positional encoding if required.

    It maps from :math:`(B, C, H, W)` to :math:`(B, hidden\_dim, N_h, N_w)`,
    where :math:`N_h \times N_w` is the number of patches in the image. The
    embedding dimension must match the expected hidden dimension of the
    transformer.
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    num_layers = 32
    num_heads = 16
    hidden_dim = 1280
    mlp_dim = 5120
    dropout = 0.0
    attention_dropout = 0.0
    norm_layer = c_nn.RMSNorm

    return c_nn.ViT(
        patch_embedder,
        num_layers,
        num_heads,
        hidden_dim,
        mlp_dim,
        dropout=dropout,
        attention_dropout=attention_dropout,
        norm_layer=norm_layer,
        **factory_kwargs
    )
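

# The sketch below is illustrative and not part of the upstream module. It
# shows how a patch embedder satisfying the contract documented above,
# mapping (B, C, H, W) to (B, hidden_dim, N_h, N_w), could be wired into
# vit_t. `ToyPatchEmbedder` is a hypothetical stand-in built from reshapes
# and a complex-valued nn.Linear; in practice one would use the patch
# embedder shipped with torchcvnn, and whether c_nn.ViT accepts an arbitrary
# module here is an assumption.
if __name__ == "__main__":

    class ToyPatchEmbedder(nn.Module):
        def __init__(self, in_channels: int, hidden_dim: int, patch_size: int):
            super().__init__()
            self.patch_size = patch_size
            # Complex-valued linear projection of each flattened patch.
            self.proj = nn.Linear(
                in_channels * patch_size * patch_size,
                hidden_dim,
                dtype=torch.complex64,
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            B, C, H, W = x.shape
            p = self.patch_size
            n_h, n_w = H // p, W // p
            # (B, C, H, W) -> (B, N_h * N_w, C * p * p)
            patches = (
                x.reshape(B, C, n_h, p, n_w, p)
                .permute(0, 2, 4, 1, 3, 5)
                .reshape(B, n_h * n_w, C * p * p)
            )
            emb = self.proj(patches)  # (B, N_h * N_w, hidden_dim)
            # (B, hidden_dim, N_h, N_w), as expected by the builders above.
            return emb.transpose(1, 2).reshape(B, -1, n_h, n_w)

    # For vit_t the embedding dimension must be 192; a 224 x 224 input with
    # 16 x 16 patches gives N_h = N_w = 14, i.e. 196 patches.
    embedder = ToyPatchEmbedder(in_channels=1, hidden_dim=192, patch_size=16)
    model = vit_t(embedder)
    x = torch.randn(2, 1, 224, 224, dtype=torch.complex64)
    print(model(x).shape)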