Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/activation.py: 99%
113 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
1# MIT License
3# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Callable
26# External imports
27import torch
28import torch.nn as nn
29import torch.nn.functional as F
31# Local imports
32from torchcvnn.nn import functional as c_F
33from torchcvnn.nn.modules.normalization import RMSNorm, LayerNorm
34from .initialization import complex_xavier_uniform_
class IndependentRealImag(nn.Module):
    """
    Generic module to apply a real valued activation function independently
    on both the real and imaginary part

    Arguments:
        fact: the class (or zero-argument factory) of a real valued
            activation function, e.g. ``nn.ReLU``. It is instantiated twice.
    """

    def __init__(self, fact: Callable[[], nn.Module]):
        super().__init__()
        # Two independent instances so that parametrized activations
        # (e.g. nn.PReLU) learn separate parameters for the real and
        # imaginary parts.
        self.act_real = fact()
        self.act_imag = fact()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z: the input complex valued tensor

        Returns:
            A complex tensor whose real and imaginary parts are the
            activations of the input's real and imaginary parts.
        """
        return self.act_real(z.real) + self.act_imag(z.imag) * 1j
class CReLU(IndependentRealImag):
    """
    Complex ReLU: a real valued ReLU applied separately to the real and
    imaginary components.

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both :math:`\\Re[z]` and :math:`\\Im[z]` are
    negative is mapped to :math:`0`; otherwise the real and/or the imaginary
    part is kept unchanged.
    """

    def __init__(self) -> None:
        super().__init__(nn.ReLU)
class CPReLU(IndependentRealImag):
    """
    Complex PReLU: a real valued PReLU applied separately to the real and
    imaginary components (each with its own learnable slope).

    :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.PReLU)
class CELU(IndependentRealImag):
    """
    Complex ELU: a real valued ELU applied separately to the real and
    imaginary components.

    Not to be confused with :external:py:class:`torch.nn.CELU`; for the
    complex counterpart of that activation, see
    :class:`torchcvnn.nn.modules.activation.CCELU`.

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.ELU)
class CCELU(IndependentRealImag):
    """
    Complex CELU: a real valued CELU applied separately to the real and
    imaginary components.

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.CELU)
class CGELU(IndependentRealImag):
    """
    Complex GELU: a real valued GELU applied separately to the real and
    imaginary components.

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.GELU)
class CSigmoid(IndependentRealImag):
    """
    Complex sigmoid: a real valued sigmoid applied separately to the real
    and imaginary components.

    As used in Nitta Tohru. An extension of the back-propagation algorithm
    to complex numbers. Neural Networks, 10(9):1391-1415, November 1997.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real valued sigmoid is applied in the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Sigmoid)
class CTanh(IndependentRealImag):
    """
    Complex tanh: a real valued tanh applied separately to the real and
    imaginary components.

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real valued tanh is applied in the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Tanh)
class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    Only the first quadrant (strictly positive real and imaginary parts) is
    preserved; the three other quadrants are mapped to :math:`0`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input complex valued tensor
        """
        in_first_quadrant = (z.real > 0) & (z.imag > 0)
        return z * in_first_quadrant
class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels all the complex plane in the circle of radius :math:`a`, where :math:`a` is
    trainable.
    """

    def __init__(self):
        super().__init__()
        # Trainable radius of the cancelled disk. Note that the hard
        # threshold in forward() is a non-differentiable comparison, so `a`
        # receives no gradient through the mask itself.
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([1.0]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # Keep values on or outside the circle of radius `a`, as stated in
        # the class docstring. The previous code used `z.abs() < self.a`,
        # which inverted the documented contract and kept the inside of the
        # disk instead.
        mask = z.abs() >= self.a
        return z * mask
class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a.z & \mbox{otherwise} \end{cases}`
    """

    def __init__(self):
        super().__init__()
        # Trainable slope applied outside the first quadrant.
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([0.2]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input complex valued tensor
        """
        in_first_quadrant = (z.real > 0) & (z.imag > 0)
        passed = z * in_first_quadrant
        leaked = self.a * (z * ~in_first_quadrant)
        return passed + leaked
class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input, mapping to :math:`\mathbb{R}`.

    :math:`Mod(z) = |z|`

    This activation function allows to go from complex values to real
    values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input complex valued tensor
        """
        return z.abs()
class modReLU(nn.Module):
    r"""
    Applies a ReLU with parametric offset on the amplitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`
    """

    def __init__(self):
        super().__init__()
        # Trainable bias added to the modulus before thresholding.
        self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input complex valued tensor
        """
        magnitude = torch.relu(z.abs() + self.b)
        phase = torch.exp(1j * z.angle())
        return magnitude * phase
class Cardioid(nn.Module):
    r"""
    The cardioid activation function as proposed by Virtue et al. (2019):

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    For real inputs, i.e. :math:`\theta \in \{0, \pi\}`, it reduces to ReLU:

    :math:`\forall r \in \mathbb{R}, \theta \in \{0, \pi\}, Cardioid(r e^{j \theta}) = ReLU(r) e^{j \theta} = ReLU(r)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input complex valued tensor
        """
        gain = (1 + torch.cos(z.angle())) / 2
        return gain * z
class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::

        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, KW^K_i, VW^V_i)`

    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}])V

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split accross `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        norm_layer: Constructor of the normalization applied per head to the queries and keys. Default: `LayerNorm`
        bias: If `True`, adds a bias to the input and output projections. Default: `True`
        batch_first: If `True`, then the input (query, key, value) and output tensors (attn_outputs) are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)
        device: Device on which the parameters are allocated. Default: `None`
        dtype: Data type of the parameters. Default: `torch.complex64`

    Example:

    .. code-block:: python

        import torchcvnn as c_nn
        import torch

        nhead = 8
        seq_len = 10
        batch_size = 32
        num_features = 512

        multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
        src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
        attn_output, attn_output_weights = multihead_attn(src, src, src)
        # attn_output is (seq_len, batch_size, num_features)

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        bias: bool = True,
        batch_first: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Per-head normalization applied to the queries and keys.
        self.q_norm = norm_layer(self.head_dim)
        self.k_norm = norm_layer(self.head_dim)

        # Packed projection weight for Q, K and V, stacked along dim 0.
        self.in_proj_weight = torch.nn.parameter.Parameter(
            torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
        )

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Initializes the projections with complex Xavier uniform weights and zero biases."""
        complex_xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)

        complex_xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            torch.nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        need_weights: bool = True,
        average_attn_weights: bool = True,
    ) -> torch.Tensor:
        """
        Computes attention outputs using query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

        Shape:
            Inputs:
            - query: :math:`(T, E)` or :math:`(T, B, E)` (``batch_first=False``) or :math:`(B, T, E)` (``batch_first=True``), where T is the target sequence length, B is the batch size, E is the embedding dimension
            - key: :math:`(S, E)` or :math:`(S, B, E)` (``batch_first=False``) or :math:`(B, S, E)` (``batch_first=True``), where S is the source sequence length, B is the batch size, E is the embedding dimension.
            - value: :math:`(S, E)` or :math:`(S, B, E)` (``batch_first=False``) or :math:`(B, S, E)` (``batch_first=True``), where S is the source sequence length, B is the batch size, E is the embedding dimension.

            Outputs:
            - attn_output: :math:`(T, E)` or :math:`(T, B, E)` (``batch_first=False``) or :math:`(B, T, E)` (``batch_first=True``), where T is the target sequence length, B is the batch size, E is the embedding dimension
            - attn_output_weights: :math:`(T, S)` or :math:`(B, T, S)`. Optional output, only returned when ``need_weights=True``.
        """
        is_batched = query.dim() == 3

        if self.batch_first and is_batched:
            # In this case, query is (B, T, E), key is (B, S, E) and value is (B, S, E)

            # These steps prevent multiple transpose on the same tensors
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (
                    x.transpose(1, 0) for x in (query, key, value)
                )  # (T, B, E), (S, B, E), (S, B, E)

        attn_output, attn_output_weights = c_F.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            q_norm=self.q_norm,
            k_norm=self.k_norm,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            dropout_p=self.dropout,
            out_proj=self.out_proj,
            training=self.training,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
        )
        # attn_output is (T, E) or (T, B, E)
        # attn_output_weights is (T, S) or (B, T, S) (already batch_first)
        if is_batched and self.batch_first:
            # Bug fix: transpose the output back to (B, T, E) while still
            # honoring `need_weights` below. The previous code returned only
            # `attn_output` on this path and silently dropped the attention
            # weights even when `need_weights=True`, so the return type
            # differed between batch_first modes.
            attn_output = attn_output.transpose(1, 0)

        if need_weights:
            return attn_output, attn_output_weights
        else:
            return attn_output