Coverage for / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / functional.py: 94%
48 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Optional, Tuple
25import math
27# External imports
28import torch
29from torch import Tensor
30import torch.nn.functional as F
def dropout(
    z: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False
) -> Tensor:
    r"""Complex-valued dropout.

    Samples a single real-valued dropout mask (entries are either ``0`` or
    ``1/(1-p)``) and multiplies the complex input by it, so the real and
    imaginary parts of an element are always dropped together.

    Args:
        z: input tensor (typically complex-valued).
        p: probability of an element to be zeroed. Must be in ``[0, 1]``.
        training: apply dropout if ``True``; otherwise ``z`` is returned
            unscaled (the mask is all ones).
        inplace: if ``True``, multiply ``z`` by the mask in place and return
            ``z`` itself. (Previously this flag was accepted but ignored.)

    Returns:
        The masked tensor, same shape and dtype as ``z``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
    # Build the scaled Bernoulli mask with the real-valued F.dropout; the
    # real mask broadcasts (with dtype promotion) against the complex input.
    mask = F.dropout(torch.ones(z.shape, device=z.device), p, training=training)
    if inplace:
        # Honor the documented in-place contract: mutate and return z.
        return z.mul_(mask)
    return mask * z
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    q_norm: torch.nn.Module,
    k_norm: torch.nn.Module,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    dropout_p: float,
    out_proj: Optional[torch.nn.Module],
    training: bool = True,
    need_weights: bool = True,
    average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Forward method for MultiHeadAttention.

    This function is adapted from pytorch
    ``torch.nn.functional.multi_head_attention_forward`` for complex-valued
    inputs: the keys are conjugated before the scaled dot product and only the
    real part of the scores is kept before the softmax (equation (8) of
    Eilers et al. 2023), so the attention weights are real even though
    ``q``, ``k`` and ``v`` may be complex.

    See :class:`torchcvnn.nn.MultiheadAttention` for details.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        q_norm, k_norm: normalization modules applied to the projected queries
            and keys before attention.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        dropout_p: probability of an element to be zeroed.
        out_proj: optional layer for the output projection; skipped if ``None``.
        training: apply dropout if is ``True``.
        need_weights: output attn_output_weights.
            Default: `True`
            Note: `need_weights` defaults to `True`, but should be set to `False`
            for best performance when attention weights are not needed.
            *Setting it to `True` leads to a significant performance degradation.*
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
            Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
            when ``need_weights=True.``. Default: True

    Shape:
        Inputs:
        - query: :math:`(T, B, E)` where T is the target sequence length, B is the batch size, E is the embedding dimension.
        - key: :math:`(S, B, E)`, where S is the source sequence length, B is the batch size, E is the embedding dimension.
        - value: :math:`(S, B, E)` where S is the source sequence length, B is the batch size, E is the embedding dimension.

        Outputs:
        - attn_output: :math:`(T, B, E)` where T is the target sequence length, B is the batch size, E is the embedding dimension.
        - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
          attention weights averaged across heads of shape
          :math:`(B, T, S)`, where :math:`B` is the batch size, :math:`T` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per head of shape :math:`(B, num_heads, T, S)`.
    """
    # Unbatched (2-D) inputs are not supported, unlike the pytorch original.
    assert query.dim() == 3, "Expected batched tensors"

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    assert (
        embed_dim == embed_dim_to_check
    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"

    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads

    assert (
        num_heads * head_dim == embed_dim
    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

    assert (
        key.shape == value.shape
    ), f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    # NOTE(review): F._in_projection_packed is a private pytorch helper; it
    # projects q/k/v with slices of the packed in_proj_weight/in_proj_bias.
    assert (
        in_proj_weight is not None
    ), "in_proj_weight is None"
    q, k, v = F._in_projection_packed(
        query, key, value, in_proj_weight, in_proj_bias
    )  # (T, B, E), (S, B, E), (S, B, E)
    # Normalize the projected queries and keys (e.g. RMSNorm/LayerNorm);
    # values are left untouched.
    q = q_norm(q)
    k = k_norm(k)

    #
    # reshape q, k, v for multihead attention and make them batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, tgt_len, head_dim
    k = k.view(src_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, src_len, head_dim
    v = v.view(src_len, bsz * num_heads, head_dim).transpose(0, 1)  # bsz * num_heads, src_len, head_dim

    # adjust dropout probability: no dropout at evaluation time
    if not training:
        dropout_p = 0.0

    #
    # (deep breath) calculate attention and out projection
    #

    # This is the "if need_weights" from the original pytorch code
    # We just adapt the case where the weights are needed
    # Indeed, since we are using specific implementations for computing
    # attention for the complex valued case, we cannot use the optimized versions
    # of the original pytorch code (flash attention or others)
    q_scaled = q * math.sqrt(1.0 / float(head_dim))  # the usual 1/sqrt(d_k) scaling

    # For equation (8) from (Eilers et al. 2023),
    # We need to conjugate the keys before computing the dot product
    k = k.conj()

    # q_scaled            # B * num_heads, tgt_len, head_dim
    # k.transpose(-2, -1) # B * num_heads, head_dim, src_len
    # bmm( X(B, n, m) , Y(B, m, p) ) = Z (B, n, p)
    attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))  # [B * num_heads, tgt_len, src_len]

    # And then take the real part of the result, so the scores fed to the
    # softmax are real-valued even for complex q/k
    attn_output_weights = attn_output_weights.real

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)  # Softmax over the src_len dimension
    if dropout_p > 0.0:
        attn_output_weights = dropout(attn_output_weights, p=dropout_p)

    # attn_output_weights are real valued while v are complex valued, hence
    # the cast to v.dtype before the batched matmul
    # attn_output_weights : B * num_heads, tgt_len, src_len
    # v : B * num_heads, src_len, head_dim
    attn_output = torch.bmm(attn_output_weights.to(v.dtype), v)  # B * num_heads, tgt_len, head_dim

    attn_output = attn_output.view(bsz, num_heads, tgt_len, head_dim)  # [B, num_heads, tgt_len, head_dim]
    attn_output = (
        attn_output.transpose(1, 2).contiguous()  # (B, tgt_len, num_heads, head_dim)
        .view(bsz * tgt_len, embed_dim)  # (B * tgt_len, num_heads * head_dim) = (B * tgt_len, embed_dim)
    )

    if out_proj is not None:
        attn_output = out_proj(attn_output)

    # Back to the sequence-first layout expected by the caller
    attn_output = (
        attn_output.view(bsz, tgt_len, embed_dim)  # B, seq_len, embed_dim
        .transpose(0, 1)  # seq_len , B, embed_dim
    )

    # Early exit if we do not need the weights
    if not need_weights:
        return attn_output, None

    # Perform the extra computation only if the weights are needed

    # optionally average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)  # (B, num_heads, tgt_len, src_len)
    if average_attn_weights:
        attn_output_weights = attn_output_weights.mean(dim=1)  # (B, tgt_len, src_len)

    return attn_output, attn_output_weights