Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/functional.py: 57%

127 statements  

coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

# MIT License

# Copyright (c) 2024 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Optional, Tuple
import math

# External imports
import torch
from torch import Tensor
import torch.nn.functional as F

def dropout(
    z: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
) -> Tensor:
    """Complex-aware dropout: a single real-valued mask, with the usual 1/(1-p)
    rescaling applied by F.dropout, multiplies the complex input so that the real
    and imaginary parts of an element are always dropped together. The ``inplace``
    flag is accepted for API compatibility but is not used."""
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    mask = F.dropout(torch.ones(z.shape), p, training=training).to(z.device)
    return mask * z

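# Illustrative usage of the dropout above (a minimal sketch with arbitrary example
# shapes, not part of the library API): the same real-valued mask multiplies both
# the real and imaginary parts, and the dtype and shape are preserved.
#
#   >>> z = torch.randn(4, 8, dtype=torch.complex64)
#   >>> out = dropout(z, p=0.3)
#   >>> out.shape, out.dtype
#   (torch.Size([4, 8]), torch.complex64)
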
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    This function is adapted from PyTorch's ``torch.nn.functional.multi_head_attention_forward``.

    See :class:`torchcvnn.nn.MultiheadAttention` for details.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
            Default: ``True``.
            Note: ``need_weights`` defaults to ``True``, but should be set to ``False``
            for best performance when attention weights are not needed.
            *Setting need_weights to ``True`` leads to a significant performance
            degradation.*
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
        is_causal: If specified, applies a causal mask as attention mask, and ignores
            attn_mask for computing scaled dot product attention.
            Default: ``False``.
            .. warning::
                is_causal provides a hint that the attn_mask is the
                causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility.
        use_separate_proj_weight: the function accepts the projection weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True``. Default: True


    Shape:
        Inputs:
        - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a FloatTensor is provided, it will be directly added to the value.
          If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend to the unmasked
          positions. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
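
    Example:
        A minimal, illustrative call with arbitrary example dimensions, using the
        packed in-projection path and randomly drawn complex projection weights::

            >>> L, S, N, E, H = 5, 7, 2, 16, 4
            >>> q = torch.randn(L, N, E, dtype=torch.complex64)
            >>> k = torch.randn(S, N, E, dtype=torch.complex64)
            >>> v = torch.randn(S, N, E, dtype=torch.complex64)
            >>> w_in = torch.randn(3 * E, E, dtype=torch.complex64)
            >>> w_out = torch.randn(E, E, dtype=torch.complex64)
            >>> out, weights = multi_head_attention_forward(
            ...     q, k, v, E, H, w_in, None, None, None, False, 0.0, w_out, None
            ... )
            >>> out.shape, weights.shape
            (torch.Size([5, 2, 16]), torch.Size([2, 5, 7]))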
    """
    if key_padding_mask is not None:
        raise NotImplementedError("key_padding_mask is not supported yet")
    if attn_mask is not None:
        raise NotImplementedError("attn_mask is not supported yet")

    is_batched = F._mha_shape_check(
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch dim to pretend that the
    # input is batched, run the computation, and squeeze the batch dimension before
    # returning so that the output does not carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query = query.unsqueeze(1)
        key = key.unsqueeze(1)
        value = value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    key_padding_mask = F._canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=F._none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    if is_causal and attn_mask is None:
        raise RuntimeError(
            "Need attn_mask if specifying the is_causal hint. "
            "You may use the Transformer module method "
            "`generate_square_subsequent_mask` to create this mask."
        )

    if is_causal and key_padding_mask is None and not need_weights:
        # When we have a kpm or need weights, we need attn_mask.
        # Otherwise, we pass the is_causal hint straight through as the
        # is_causal indicator to SDPA.
        attn_mask = None
    else:
        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            # We have the attn_mask, and use that to merge kpm into it.
            # Turn off use of is_causal hint, as the merged mask is no
            # longer causal.
            is_causal = False

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    assert (
        head_dim * num_heads == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert (
            key.shape[:2] == value.shape[:2]
        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        assert (
            in_proj_weight is not None
        ), "use_separate_proj_weight is False but in_proj_weight is None"
        q, k, v = F._in_projection_packed(
            query, key, value, in_proj_weight, in_proj_bias
        )
    else:
        assert (
            q_proj_weight is not None
        ), "use_separate_proj_weight is True but q_proj_weight is None"
        assert (
            k_proj_weight is not None
        ), "use_separate_proj_weight is True but k_proj_weight is None"
        assert (
            v_proj_weight is not None
        ), "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = F._in_projection(
            query,
            key,
            value,
            q_proj_weight,
            k_proj_weight,
            v_proj_weight,
            b_q,
            b_k,
            b_v,
        )

    # prep attention mask

    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_k.size(0) == bsz * num_heads
        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
        assert (
            static_k.size(2) == head_dim
        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert (
            static_v.size(0) == bsz * num_heads
        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
        assert (
            static_v.size(2) == head_dim
        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
        v = static_v

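    # At this point q, k and v are batch-first: q has shape (bsz * num_heads, tgt_len, head_dim),
    # while k and v have shape (bsz * num_heads, -1, head_dim), their sequence length possibly
    # already extended by one position by bias_k / bias_v above.
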
    # add zero attention along batch dimension (now first)
    if add_zero_attn:
        zero_attn_shape = (bsz * num_heads, 1, head_dim)
        k = torch.cat(
            [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
        )
        v = torch.cat(
            [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
        )
        if attn_mask is not None:
            attn_mask = F.pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = F.pad(key_padding_mask, (0, 1))

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            F._check_key_padding_mask(key_padding_mask, src_len, bsz)

        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    # This is the "if need_weights" branch from the original pytorch code:
    # we only adapt the case where the weights are needed.
    # Indeed, since we use a specific implementation to compute attention in the
    # complex-valued case, we cannot rely on the optimized kernels of the
    # original pytorch code (flash attention or others).
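    # Concretely, the attention weights computed below are (equation (8) in
    # Eilers et al., 2023):
    #
    #     alpha_ij = softmax_j( Re( q_i . conj(k_j) ) / sqrt(head_dim) )
    #
    # i.e. a real-valued score obtained from the scaled Hermitian inner product of
    # queries and keys, which is then applied to the complex-valued values v.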
    _B, _Nt, E = q.shape
    q_scaled = q * math.sqrt(1.0 / float(E))

    assert not (
        is_causal and attn_mask is None
    ), "FIXME: is_causal not implemented for need_weights"

    # For equation (8) from (Eilers et al., 2023),
    # we need to conjugate the keys before computing the dot product
    k = k.conj()

    if attn_mask is not None:
        attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
    else:
        attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))

    # And then take the real part of the result
    attn_output_weights = attn_output_weights.real

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    if dropout_p > 0.0:
        attn_output_weights = dropout(attn_output_weights, p=dropout_p)

    # attn_output_weights is real-valued while v is complex-valued
    attn_output = torch.bmm(attn_output_weights.to(v.dtype), v)

    attn_output = (
        attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    )
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    # Early exit if we do not need the weights
    if not need_weights:
        if not is_batched:
            # squeeze the output if input was unbatched
            attn_output = attn_output.squeeze(1)
        return attn_output, None

    # Perform the extra computation only if the weights are needed

    # optionally average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    if average_attn_weights:
        attn_output_weights = attn_output_weights.mean(dim=1)

    if not is_batched:
        # squeeze the output if input was unbatched
        attn_output = attn_output.squeeze(1)
        attn_output_weights = attn_output_weights.squeeze(0)
    return attn_output, attn_output_weights
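

# Note on usage (an illustrative, hedged sketch): this functional form is normally
# reached through the torchcvnn.nn.MultiheadAttention module rather than called
# directly. Assuming a constructor that mirrors torch.nn.MultiheadAttention (an
# assumption, not taken from this file), a call would look roughly like:
#
#   import torchcvnn.nn as c_nn
#   mha = c_nn.MultiheadAttention(embed_dim=16, num_heads=4)  # assumed signature
#   q = k = v = torch.randn(5, 2, 16, dtype=torch.complex64)
#   out, weights = mha(q, k, v)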