Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/activation.py: 95% (138 statements)
# MIT License

# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Optional

# External imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local imports
from torchcvnn.nn import functional as c_F
from .initialization import complex_xavier_uniform_


class IndependentRealImag(nn.Module):
    """
    Generic module to apply a real valued activation function independently
    on both the real and imaginary parts

    Arguments:
        fact: The class of the real valued activation function to apply (e.g. ``nn.ReLU``)
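
    For instance, a hypothetical complex SELU could be built by wrapping
    :external:py:class:`torch.nn.SELU` (illustrative sketch, not part of the library):

    .. code-block:: python

        import torch
        import torch.nn as nn

        from torchcvnn.nn.modules.activation import IndependentRealImag

        class CSELU(IndependentRealImag):
            def __init__(self) -> None:
                super().__init__(nn.SELU)

        z = torch.randn(4, dtype=torch.complex64)
        out = CSELU()(z)  # SELU applied separately to z.real and z.imag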
    """

    def __init__(self, fact: nn.Module):
        super().__init__()
        self.act_real = fact()
        self.act_imag = fact()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return self.act_real(z.real) + self.act_imag(z.imag) * 1j


class CReLU(IndependentRealImag):
    """
    Applies a ReLU independently on both the real and imaginary parts

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both :math:`\\Re[z]` and :math:`\\Im[z]` are negative is projected to
    :math:`0`. Otherwise, either the real and/or the imaginary part is preserved.
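
    A minimal usage sketch (illustrative); the other :class:`IndependentRealImag`
    activations are used the same way:

    .. code-block:: python

        import torch
        import torchcvnn.nn as c_nn

        z = torch.randn(8, dtype=torch.complex64)
        out = c_nn.CReLU()(z)  # ReLU(z.real) + ReLU(z.imag) * 1j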
    """

    def __init__(self) -> None:
        super().__init__(nn.ReLU)


class CPReLU(IndependentRealImag):
    """
    Applies a PReLU independently on both the real and imaginary parts

    :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.PReLU)


class CELU(IndependentRealImag):
    """
    Applies an ELU independently on both the real and imaginary parts

    Not to be confused with `torch.nn.CELU`. For the complex equivalent of
    :external:py:class:`torch.nn.CELU`, see :class:`torchcvnn.nn.modules.activation.CCELU`

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.ELU)


class CCELU(IndependentRealImag):
    """
    Applies a CELU independently on both the real and imaginary parts

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.CELU)


class CGELU(IndependentRealImag):
    """
    Applies a GELU independently on both the real and imaginary parts

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.GELU)


class CSigmoid(IndependentRealImag):
    """
    Applies a Sigmoid independently on both the real and imaginary parts,

    as used in Nitta, T. (1997). An extension of the back-propagation algorithm to complex numbers. Neural Networks, 10(9):1391–1415.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real valued sigmoid is applied to the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Sigmoid)


class CTanh(IndependentRealImag):
    """
    Applies a Tanh independently on both the real and imaginary parts

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real valued tanh is applied to the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Tanh)


class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    Every quadrant where :math:`\Re[z]` or :math:`\Im[z]` is non positive is
    projected to :math:`0`. In other words, only the first quadrant is preserved.
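
    A minimal usage sketch (illustrative):

    .. code-block:: python

        import torch
        import torchcvnn.nn as c_nn

        z = torch.tensor([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j])
        out = c_nn.zReLU()(z)
        # only 1+1j (first quadrant) is preserved, the other values are set to 0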
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img


class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels the complex plane inside the disc of radius :math:`a`, where :math:`a` is
    trainable.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([1.0]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # Keep the inputs whose magnitude is at least a, zero out the disc |z| < a
        mask = z.abs() >= self.a
        return z * mask


class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a \cdot z & \mbox{otherwise} \end{cases}`

    where the slope :math:`a` is trainable.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([0.2]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img + self.a * (z * ~(pos_real * pos_img))


class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input. It maps to :math:`\mathbb{R}`

    :math:`Mod(z) = |z|`

    This activation function makes it possible to go from complex valued to real
    valued tensors.
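
    It is typically used before a real valued head or loss; a minimal sketch (illustrative):

    .. code-block:: python

        import torch
        import torchcvnn.nn as c_nn

        z = torch.randn(16, 32, dtype=torch.complex64)
        magnitude = c_nn.Mod()(z)                  # real valued, same shape as z
        logits = torch.nn.Linear(32, 10)(magnitude)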
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return torch.abs(z)


class modReLU(nn.Module):
    r"""
    Applies a ReLU with a parametric offset on the amplitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`

    where :math:`\theta` is the phase of :math:`z` and the offset :math:`b` is trainable.
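
    A minimal usage sketch (illustrative):

    .. code-block:: python

        import torch
        import torchcvnn.nn as c_nn

        act = c_nn.modReLU()
        z = torch.randn(8, dtype=torch.complex64)
        out = act(z)  # same phase as z, magnitude ReLU(|z| + b)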
257 """
259 def __init__(self):
260 super().__init__()
261 self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)
263 def forward(self, z: torch.Tensor):
264 """
265 Performs the forward pass.
267 Arguments:
268 z: the input tensor on which to apply the activation function
269 """
270 return nn.functional.relu(z.abs() + self.b) * torch.exp(1j * z.angle())


class Cardioid(nn.Module):
    r"""
    The cardioid activation function, as proposed by Virtue et al. (2019), is given by:

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    where :math:`\theta` is the phase of :math:`z`. For real valued inputs, i.e.
    :math:`\theta \in \{0, \pi\}`, it reduces to the ReLU:

    :math:`\forall z \in \mathbb{R}, Cardioid(z) = ReLU(z)`
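
    A quick check of this reduction (illustrative):

    .. code-block:: python

        import torch
        import torchcvnn.nn as c_nn

        x = torch.tensor([-2.0, 3.0], dtype=torch.complex64)
        out = c_nn.Cardioid()(x)
        # -2 is mapped to (approximately) 0 and 3 is kept, i.e. ReLU on the real axis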
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return 0.5 * (1 + torch.cos(z.angle())) * z


class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces, as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::

        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, K W^K_i, V W^V_i)`

    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}])V
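
    That is, the attention scores are the softmax of the real part of the scaled Hermitian
    inner products between queries and keys. A rough sketch of this step (illustrative,
    not the actual internal implementation, which lives in
    :py:func:`torchcvnn.nn.functional.multi_head_attention_forward`):

    .. code-block:: python

        import math
        import torch

        seq_len, d_k = 10, 64
        q = torch.randn(seq_len, d_k, dtype=torch.complex64)
        k = torch.randn(seq_len, d_k, dtype=torch.complex64)
        v = torch.randn(seq_len, d_k, dtype=torch.complex64)

        scores = torch.softmax((q @ k.conj().T).real / math.sqrt(d_k), dim=-1)
        out = scores.to(v.dtype) @ v  # real valued weights, complex valued output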

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        kdim: Total number of features for keys. Default `None` which uses `kdim=embed_dim`
        vdim: Total number of features for values. Default `None` which uses `vdim=embed_dim`
        batch_first: If `True`, then the input and output tensors are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)

    Example:

    .. code-block:: python

        import torchcvnn.nn as c_nn
        import torch

        nhead = 8
        seq_len = 10
        batch_size = 32
        num_features = 512

        multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
        src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
        attn_output, attn_output_weights = multihead_attn(src, src, src)
        # attn_output is (seq_len, batch_size, num_features)
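
        # With batch_first=True, inputs and outputs are (batch, seq, feature) instead;
        # an illustrative variant of the call above:
        multihead_attn = c_nn.MultiheadAttention(
            embed_dim=num_features, num_heads=nhead, batch_first=True
        )
        src = torch.rand(batch_size, seq_len, num_features, dtype=torch.complex64)
        attn_output, attn_output_weights = multihead_attn(src, src, src)
        # attn_output is (batch_size, seq_len, num_features)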
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.complex64,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs)
            )
            self.k_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs)
            )
            self.v_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        if add_bias_kv:
            self.bias_k = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            complex_xavier_uniform_(self.in_proj_weight)
        else:
            complex_xavier_uniform_(self.q_proj_weight)
            complex_xavier_uniform_(self.k_proj_weight)
            complex_xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)
            torch.nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            torch.nn.init.constant_(self.bias_k, 0.0)
        if self.bias_v is not None:
            torch.nn.init.constant_(self.bias_v, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> tuple:
        """
        Computes the attention output and attention weights from the query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex valued tensors. It keeps the same
        signature but does not yet support key_padding_mask and attn_mask.

        Returns a tuple (attn_output, attn_output_weights).
        """
        is_batched = query.dim() == 3

        if key_padding_mask is not None:
            raise NotImplementedError("key_padding_mask is not supported yet")
            # key_padding_mask = F._canonical_mask(
            #     mask=key_padding_mask,
            #     mask_name="key_padding_mask",
            #     other_type=F._none_or_dtype(attn_mask),
            #     other_name="attn_mask",
            #     target_type=query.dtype,  # Adapted because q is complex
            # )
            # But F._canonical_mask raises an exception:
            # AssertionError: only bool and floating types of key_padding_mask are supported

        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not supported yet")
            # attn_mask = F._canonical_mask(
            #     mask=attn_mask,
            #     mask_name="attn_mask",
            #     other_type=None,
            #     other_name="",
            #     target_type=query.dtype,  # Adapted because q is complex
            #     check_other=False,
            # )

        if self.batch_first and is_batched:
            # These steps prevent multiple transposes on the same tensors,
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights