Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/vit.py: 100%

37 statements  

coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

# MIT License

# Copyright (c) 2024 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Callable

# External imports
import torch
import torch.nn as nn

# Local imports
from .normalization import LayerNorm
from .activation import modReLU, MultiheadAttention
from .dropout import Dropout


class ViTLayer(nn.Module):

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        """
        The ViT layer cascades a multi-head attention block with a feed-forward network.

        Args:
            num_heads: Number of heads in the multi-head attention block.
            hidden_dim: Hidden dimension of the transformer.
            mlp_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).

        .. math::

            x & = x + \\text{attn}(\\text{norm1}(x))\\\\
            x & = x + \\text{ffn}(\\text{norm2}(x))

        The FFN block is a two-layer MLP with a modReLU activation function.

        """
        super(ViTLayer, self).__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.norm1 = norm_layer(hidden_dim, **factory_kwargs)
        self.attn = MultiheadAttention(
            embed_dim=hidden_dim,
            dropout=attention_dropout,
            num_heads=num_heads,
            batch_first=True,
            **factory_kwargs
        )
        self.dropout = Dropout(dropout)
        self.norm2 = norm_layer(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim, **factory_kwargs),
            norm_layer(mlp_dim),
            modReLU(),
            Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim, **factory_kwargs),
            Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass through the layer using pre-normalization.

        Args:
            x: Input tensor of shape (B, seq_len, hidden_dim)
        """
        norm_x = self.norm1(x)
        x = x + self.dropout(self.attn(norm_x, norm_x, norm_x, need_weights=False)[0])
        x = x + self.ffn(self.norm2(x))

        return x

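# An illustrative usage sketch of ViTLayer (not part of the original module); it
# assumes arbitrary example sizes and only checks that the layer preserves the
# (B, seq_len, hidden_dim) shape of a complex-valued input:
#
#     layer = ViTLayer(num_heads=4, hidden_dim=32, mlp_dim=64)
#     x = torch.randn(8, 16, 32, dtype=torch.complex64)
#     assert layer(x).shape == x.shape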

class ViT(nn.Module):

    def __init__(
        self,
        patch_embedder: nn.Module,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        """
        Vision Transformer model. This implementation does not contain any head.

        For classification, you can, for example, compute a global average of the output embeddings:

        .. code-block:: python

            backbone = c_nn.ViT(
                embedder,
                num_layers,
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                norm_layer=norm_layer,
            )

            # A Linear decoding head to project on the logits
            head = nn.Sequential(
                nn.Linear(hidden_dim, 10, dtype=torch.complex64), c_nn.Mod()
            )

            x = torch.randn(B, C, H, W)
            features = backbone(x)  # B, num_patches, hidden_dim

            # Global average pooling of the patches encoding
            mean_features = features.mean(dim=1)  # B, hidden_dim

            head(mean_features)

        Args:
            patch_embedder: PatchEmbedder instance.
            num_layers: Number of layers in the transformer.
            num_heads: Number of heads in the multi-head attention block.
            hidden_dim: Hidden dimension of the transformer.
            mlp_dim: Hidden dimension of the feed-forward network.
            dropout: Dropout rate (default: 0.0).
            attention_dropout: Dropout rate in the attention block (default: 0.0).
            norm_layer: Normalization layer (default :py:class:`LayerNorm`).
        """
        super(ViT, self).__init__()

        factory_kwargs = {"device": device, "dtype": dtype}

        self.patch_embedder = patch_embedder
        self.dropout = Dropout(dropout)
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(
                ViTLayer(
                    num_heads=num_heads,
                    hidden_dim=hidden_dim,
                    mlp_dim=mlp_dim,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    norm_layer=norm_layer,
                    **factory_kwargs
                )
            )
        self.layers = nn.Sequential(*self.layers)
        self.norm = norm_layer(hidden_dim, **factory_kwargs)

    def forward(self, x):
        # x : (B, C, H, W)
        embedding = self.patch_embedder(x)  # (B, embed_dim, num_patch_H, num_patch_W)

        # Transpose to (B, "seq_len"=num_patches, embed_dim)
        embedding = embedding.flatten(2).transpose(1, 2).contiguous()
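        # For example, a (B, 192, 14, 14) embedding map becomes (B, 196, 192):
        # flatten(2) merges the two patch-grid dimensions into 14 * 14 = 196 tokens,
        # and transpose(1, 2) moves the token axis in front of the embedding axis.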

        out = self.layers(self.dropout(embedding))

        out = self.norm(out)

        # The output is already batch first: (B, num_patches, hidden_dim)
        return out
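

# An illustrative end-to-end sketch (not part of the original module). The
# ToyPatchEmbedder below is a hypothetical stand-in for the library's patch
# embedder, used only to show the expected (B, C, H, W) -> (B, embed_dim,
# num_patch_H, num_patch_W) contract; all sizes are arbitrary example values.
#
#     class ToyPatchEmbedder(nn.Module):
#         def __init__(self, in_channels, embed_dim, patch_size):
#             super().__init__()
#             self.p = patch_size
#             self.proj = nn.Linear(
#                 in_channels * patch_size**2, embed_dim, dtype=torch.complex64
#             )
#
#         def forward(self, x):
#             p = self.p
#             # (B, C, H, W) -> (B, C, H/p, W/p, p, p): non-overlapping patches
#             patches = x.unfold(2, p, p).unfold(3, p, p)
#             # (B, H/p, W/p, C*p*p): one flattened vector per patch
#             patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(3)
#             # (B, embed_dim, H/p, W/p), as expected by ViT.forward
#             return self.proj(patches).permute(0, 3, 1, 2)
#
#     embedder = ToyPatchEmbedder(in_channels=3, embed_dim=32, patch_size=4)
#     model = ViT(embedder, num_layers=2, num_heads=4, hidden_dim=32, mlp_dim=64)
#     x = torch.randn(8, 3, 16, 16, dtype=torch.complex64)
#     features = model(x)  # (8, 16, 32): 16 tokens of dim 32 from a 16x16 input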