Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/functional.py: 94%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 06:48 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Optional, Tuple 

25import math 

26 

27# External imports 

28import torch 

29from torch import Tensor 

30import torch.nn.functional as F 

31 

32 

def dropout(
    z: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
) -> Tensor:
    r"""Applies dropout to a complex-valued tensor.

    A single real-valued dropout mask is sampled (and rescaled by
    :math:`1/(1-p)` by ``torch.nn.functional.dropout``) and multiplied into
    the input, so the real and imaginary parts of an entry are always kept
    or dropped together.

    Args:
        z: input tensor (typically complex-valued).
        p: probability of an element to be zeroed. Default: 0.5.
        training: apply dropout if ``True``; otherwise the input is returned
            unchanged (up to the multiplication by an all-ones mask).
        inplace: if ``True``, the mask is multiplied into ``z`` in place and
            ``z`` itself is returned.

    Returns:
        The masked tensor (``z`` itself when ``inplace=True``).

    Raises:
        ValueError: if ``p`` is not in ``[0, 1]``.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    # Sample the (already rescaled) keep-mask on the same device as the input;
    # F.dropout on a tensor of ones yields entries equal to 0 or 1/(1-p).
    mask = F.dropout(torch.ones(z.shape, device=z.device), p, training=training)
    if inplace:
        # Fix: the original accepted ``inplace`` but ignored it; honor it here.
        return z.mul_(mask)
    return mask * z

40 

41 

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    q_norm: torch.nn.Module,
    k_norm: torch.nn.Module,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    dropout_p: float,
    out_proj: Optional[torch.nn.Module],
    training: bool = True,
    need_weights: bool = True,
    average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    This function is adapted from pytorch torch.nn.functional.multi_head_attention_forward
    to support complex-valued inputs: attention scores are the real part of the
    scaled dot product between the queries and the conjugated keys, then passed
    through a real-valued softmax.

    See :class:`torchcvnn.nn.MultiheadAttention` for details.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        q_norm, k_norm: modules applied to the projected queries and keys before
            the attention scores are computed (e.g. a normalization layer).
        in_proj_weight, in_proj_bias: input projection weight and bias.
        dropout_p: probability of an element to be zeroed.
        out_proj: layer for the output projection
        training: apply dropout if is ``True``.
        need_weights: output attn_output_weights.
            Default: `True`
            Note: `needs_weight` defaults to `True`, but should be set to `False`
            For best performance when attention weights are not needed.
            *Setting needs_weights to `True`
            leads to a significant performance degradation.*
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True.``. Default: True

    Shape:
        Inputs:
        - query: :math:`(T, B, E)` where T is the target sequence length, B is the batch size, E is the embedding dimension.
        - key: :math:`(S, B, E)`, where S is the source sequence length, B is the batch size, E is the embedding dimension.
        - value: :math:`(S, B, E)` where S is the source sequence length, B is the batch size, E is the embedding dimension.

        Outputs:
        - attn_output: :math:`(T, B, E)` where T is the target sequence length, B is the batch size, E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape
          :math:`(B, T, S)`, where :math:`B` is the batch size, :math:`T` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(B, num_heads, T, S)`.
    """

    # Only the sequence-first, batched layout (T, B, E) is supported.
    assert query.dim() == 3, "Expected batched tensors"

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"

    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads

    assert (
        num_heads * head_dim == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    assert (
        key.shape == value.shape
    ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    assert (
        in_proj_weight is not None
    ), "in_proj_weight is None"
    # NOTE(review): F._in_projection_packed is a private torch API; its
    # signature/behavior may change across pytorch versions — keep an eye on it.
    q, k, v = F._in_projection_packed(
        query, key, value, in_proj_weight, in_proj_bias
    )  # (T, B, E), (S, B, E), (S, B, E)
    # Normalize the projected queries and keys before the dot product.
    q = q_norm(q)
    k = k_norm(k)

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, tgt_len, head_dim
    k = k.view(src_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, src_len, head_dim
    v = v.view(src_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, src_len, head_dim

    # adjust dropout probability: dropout is only active during training
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    # This is the "if need_weights" from the original pytorch code
    # We just adapt the case where the weights are needed
    # Indeed, since we are using specific implementations for computing
    # attention for the complex valued case, we cannot use the optimized versions
    # of the original pytorch code (flash attention or others)
    # Standard 1/sqrt(d_k) scaling applied to the queries.
    q_scaled = q * math.sqrt(1.0 / float(head_dim))

    # For equation (8) from (Eilers et al. 2023),
    # We need to conjugate the keys before computing the dot product
    k = k.conj()

    # q_scaled # B * num_heads, tgt_len, head_dim
    # k.transpose(-2, -1) # B * num_heads, head_dim, src_len
    # bmm( X(B, n, m) , Y(B, m, p) ) = Z (B, n, p)
    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))  # [B * num_heads, tgt_len, src_len]

    # And then take the real part of the result, so the scores feeding the
    # softmax are real-valued
    attn_output_weights = attn_output_weights.real

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)  # Softmax over the src_len dimension
    if dropout_p > 0.0:
        # Uses the local complex-aware dropout (real mask applied multiplicatively)
        attn_output_weights = dropout(attn_output_weights, p=dropout_p)

    # attn_output_weights are real valued while v are complex valued
    # attn_output_weights : B * num_heads, tgt_len, src_len
    # v : B * num_heads, src_len, head_dim
    # Cast the real weights to v's (complex) dtype so bmm is well-defined.
    attn_output = torch.bmm(attn_output_weights.to(v.dtype), v)  # B * num_heads, tgt_len, head_dim

    attn_output = attn_output.view(bsz, num_heads, tgt_len, head_dim)  # [B, num_heads, tgt_len, head_dim]
    attn_output = (
        attn_output.transpose(1, 2).contiguous()  # (B, tgt_len, num_heads, head_dim)
        .view(bsz * tgt_len, embed_dim)  # (B * tgt_len, num_heads * head_dim) = (B * tgt_len, embed_dim)
    )

    if out_proj is not None:
        attn_output = out_proj(attn_output)

    # Restore the sequence-first layout expected by the caller.
    attn_output = (
        attn_output.view(bsz, tgt_len, embed_dim)  # B, seq_len, embed_dim
        .transpose(0, 1)  # seq_len , B, embed_dim
    )

    # Early exit if we do not need the weights
    if not need_weights:
        return attn_output, None

    # Perform the extra computation only if the weights are needed

    # optionally average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)  # (B, num_heads, tgt_len, src_len)
    if average_attn_weights:
        attn_output_weights = attn_output_weights.mean(dim=1)  # (B, tgt_len, src_len)

    return attn_output, attn_output_weights