Coverage for  / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / vit.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 06:48 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Callable 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29 

30# Local imports 

31from .normalization import LayerNorm 

32from .activation import MultiheadAttention, CGELU 

33from .dropout import Dropout 

34 

35 

class ViTLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        activation_fn: Callable[[], nn.Module] = CGELU,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        """
        The ViT layer cascades a multi-head attention block with a feed-forward network.

        Args:
            num_heads: Number of heads in the multi-head attention block.
            embed_dim: Hidden dimension of the transformer.
            hidden_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
            activation_fn: Activation used in the feed-forward network
                (default :py:class:`CGELU`).
            device: Device on which the parameters are allocated (default: None).
            dtype: Dtype of the parameters (default: ``torch.complex64``).

        .. math::

            x & = x + \\text{attn}(\\text{norm1}(x))\\\\
            x & = x + \\text{ffn}(\\text{norm2}(x))

        The FFN block is a two-layer MLP using ``activation_fn`` as its
        non-linearity.
        """
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        # Forward the factory kwargs to the normalization layers too, so the
        # whole layer is allocated on the requested device/dtype (consistent
        # with how ViT constructs its final norm layer).
        self.norm1 = norm_layer(embed_dim, **factory_kwargs)
        self.attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            norm_layer=norm_layer,
            batch_first=True,
            **factory_kwargs,
        )
        self.dropout1 = Dropout(dropout)

        # Pre-norm two-layer MLP: norm -> linear -> act -> drop, twice.
        self.ffn = nn.Sequential(
            norm_layer(embed_dim, **factory_kwargs),
            nn.Linear(embed_dim, hidden_dim, **factory_kwargs),
            activation_fn(),
            Dropout(dropout),
            norm_layer(hidden_dim, **factory_kwargs),
            nn.Linear(hidden_dim, embed_dim, **factory_kwargs),
            activation_fn(),
            Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass through the layer using pre-normalization.

        Args:
            x: Input tensor of shape (B, seq_len, embed_dim)

        Returns:
            A tensor of the same shape as ``x``.
        """
        norm_x = self.norm1(x)
        # NOTE(review): assumes the project's MultiheadAttention returns the
        # attention output tensor directly (not an (output, weights) tuple)
        # when need_weights=False — confirm against its implementation.
        x = x + self.dropout1(self.attn(norm_x, norm_x, norm_x, need_weights=False))
        x = x + self.ffn(x)

        return x

107 

108 

class ViT(nn.Module):

    def __init__(
        self,
        patch_embedder: nn.Module,
        num_layers: int,
        num_heads: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """
        Vision Transformer model. This implementation does not contain any head.

        For classification, you can for example compute a global average of the output embeddings :

        .. code-block:: python

            backbone = c_nn.ViT(
                embedder,
                num_layers,
                num_heads,
                embed_dim,
                hidden_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                norm_layer=norm_layer,
            )

            # A Linear decoding head to project on the logits
            head = nn.Sequential(
                nn.Linear(embed_dim, 10, dtype=torch.complex64), c_nn.Mod()
            )

            x = torch.randn(B, C, H, W)
            features = backbone(x)  # B, num_patches, embed_dim

            # Global average pooling of the patches encoding
            mean_features = features.mean(dim=1)  # B, embed_dim

            head(mean_features)

        Args:
            patch_embedder: PatchEmbedder instance.
            num_layers: Number of layers in the transformer.
            num_heads: Number of heads in the multi-head attention block.
            embed_dim: Hidden dimension of the transformer.
            hidden_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
            device: Device on which the parameters are allocated (default: None).
            dtype: Dtype of the parameters (default: ``torch.complex64``).
        """
        super().__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.patch_embedder = patch_embedder
        self.dropout = Dropout(dropout)
        # Build the stack of transformer layers directly as a Sequential;
        # no need for an intermediate ModuleList.
        self.layers = nn.Sequential(
            *(
                ViTLayer(
                    num_heads=num_heads,
                    embed_dim=embed_dim,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    norm_layer=norm_layer,
                    **factory_kwargs,
                )
                for _ in range(num_layers)
            )
        )
        self.norm = norm_layer(embed_dim, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embeds the input with the patch embedder, then runs the transformer
        layers and the final normalization.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            The normalized patch embeddings of shape (B, src_len, embed_dim).
        """
        embedding = self.patch_embedder(x)  # (B, src_len, embed_dim)

        out = self.layers(self.dropout(embedding))

        return self.norm(out)