Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/activation.py: 95% (139 statements), coverage.py v7.8.0, created at 2025-04-13 08:53 +0000
# MIT License
#
# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
from typing import Optional, Tuple

# External imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local imports
from torchcvnn.nn import functional as c_F
from .initialization import complex_xavier_uniform_


class IndependentRealImag(nn.Module):
    """
    Generic module to apply a real-valued activation function independently
    on both the real and imaginary parts

    Arguments:
        fact: the ``nn.Module`` class of a real-valued activation function (e.g. ``nn.ReLU``)
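
    Example:

        A minimal sketch of wrapping a real-valued activation, assuming the class is
        used directly from this module (any real-valued ``nn.Module`` class works):

        .. code-block:: python

            import torch
            import torch.nn as nn

            act = IndependentRealImag(nn.SiLU)  # one SiLU for the real part, one for the imaginary part
            z = torch.randn(4, dtype=torch.complex64)
            out = act(z)  # out.real = SiLU(z.real), out.imag = SiLU(z.imag)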
43 """
45 def __init__(self, fact: nn.Module):
46 super().__init__()
47 self.act_real = fact()
48 self.act_imag = fact()
50 def forward(self, z: torch.tensor) -> torch.tensor:
51 """
52 Performs the forward pass
54 Arguments:
55 z: the input tensor on which to apply the activation function
56 """
57 return self.act_real(z.real) + self.act_imag(z.imag) * 1j


class CReLU(IndependentRealImag):
    """
    Applies a ReLU independently on both the real and imaginary parts

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both :math:`\\Re[z]` and :math:`\\Im[z]` are negative is projected to
    :math:`0`. Otherwise, the real part and/or the imaginary part is preserved.
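
    Example:

        A short usage sketch, assuming the activation is exported from ``torchcvnn.nn``
        as in the other examples of this documentation:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.CReLU()
            z = torch.tensor([1.0 - 2.0j, -1.0 + 2.0j, -1.0 - 2.0j])
            act(z)  # 1+0j, 0+2j, 0+0j : only the all-negative quadrant is fully zeroed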
69 """
71 def __init__(self) -> None:
72 super().__init__(nn.ReLU)
75class CPReLU(IndependentRealImag):
76 """
77 Applies a PReLU independently on both the real and imaginary parts
79 :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`
80 """
82 def __init__(self) -> None:
83 super().__init__(nn.PReLU)


class CELU(IndependentRealImag):
    """
    Applies an ELU independently on both the real and imaginary parts

    Not to be confused with ``torch.nn.CELU``. For the complex equivalent of
    :external:py:class:`torch.nn.CELU`, see :class:`torchcvnn.nn.modules.activation.CCELU`

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
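
    Example:

        A small sketch of the naming distinction, assuming both activations are
        exported from ``torchcvnn.nn``:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            z = torch.randn(8, dtype=torch.complex64)
            y_elu = c_nn.CELU()(z)    # applies nn.ELU to the real and imaginary parts
            y_celu = c_nn.CCELU()(z)  # applies nn.CELU to the real and imaginary parts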
94 """
96 def __init__(self) -> None:
97 super().__init__(nn.ELU)


class CCELU(IndependentRealImag):
    """
    Applies a CELU independently on both the real and imaginary parts

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.CELU)


class CGELU(IndependentRealImag):
    """
    Applies a GELU independently on both the real and imaginary parts

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.GELU)


class CSigmoid(IndependentRealImag):
    """
    Applies a Sigmoid independently on both the real and imaginary parts,
    as used in Nitta, Tohru. An extension of the back-propagation algorithm to complex numbers. Neural Networks, 10(9):1391–1415, November 1997.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real-valued sigmoid is applied to the right-hand-side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Sigmoid)


class CTanh(IndependentRealImag):
    """
    Applies a Tanh independently on both the real and imaginary parts

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real-valued tanh is applied to the right-hand-side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Tanh)


class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    All the quadrants where :math:`\Re[z]` or :math:`\Im[z]` is negative are
    projected to :math:`0`. In other words, only the quadrant where both the real
    and imaginary parts are positive is preserved.
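
    Example:

        A short sketch of the quadrant behaviour, assuming the activation is exported
        from ``torchcvnn.nn``:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.zReLU()
            z = torch.tensor([1.0 + 2.0j, -1.0 + 2.0j, 1.0 - 2.0j])
            act(z)  # 1+2j, 0+0j, 0+0j : only the first quadrant survives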
158 """
160 def __init__(self):
161 super().__init__()
163 def forward(self, z: torch.Tensor):
164 """
165 Performs the forward pass.
167 Arguments:
168 z: the input tensor on which to apply the activation function
169 """
170 pos_real = z.real > 0
171 pos_img = z.imag > 0
172 return z * pos_real * pos_img


class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels the whole disc of radius :math:`a` in the complex plane, where
    :math:`a` is a trainable parameter.
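
    Example:

        A minimal sketch of the magnitude thresholding, assuming the activation is
        exported from ``torchcvnn.nn``; the radius ``a`` is initialized to ``1.0``:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.zAbsReLU()
            z = torch.tensor([0.3 + 0.4j, 3.0 + 4.0j])  # magnitudes 0.5 and 5.0
            act(z)  # 0+0j, 3+4j : values inside the disc of radius a are cancelled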
183 """
185 def __init__(self):
186 super().__init__()
187 self.a = torch.nn.parameter.Parameter(
188 data=torch.Tensor([1.0]), requires_grad=True
189 )
191 def forward(self, z: torch.Tensor):
192 """
193 Performs the forward pass.
195 Arguments:
196 z: the input tensor on which to apply the activation function
197 """
198 mask = z.abs() < self.a
199 return z * mask


class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a \cdot z & \mbox{otherwise} \end{cases}`

    where the slope :math:`a` is a trainable parameter.
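
    Example:

        A short sketch, assuming the activation is exported from ``torchcvnn.nn``;
        the slope ``a`` is initialized to ``0.2``:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.zLeakyReLU()
            z = torch.tensor([1.0 + 2.0j, -1.0 - 2.0j])
            act(z)  # 1+2j unchanged, -1-2j scaled by the trainable slope a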
208 """
210 def __init__(self):
211 super().__init__()
212 self.a = torch.nn.parameter.Parameter(
213 data=torch.Tensor([0.2]), requires_grad=True
214 )
216 def forward(self, z: torch.Tensor):
217 """
218 Performs the forward pass.
220 Arguments:
221 z: the input tensor on which to apply the activation function
222 """
223 pos_real = z.real > 0
224 pos_img = z.imag > 0
225 return z * pos_real * pos_img + self.a * (z * ~(pos_real * pos_img))


class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input. It maps to :math:`\mathbb{R}`

    :math:`Mod(z) = |z|`

    This activation function allows going from complex-valued to real-valued
    tensors.
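
    Example:

        A minimal sketch of bridging to a real-valued head, assuming the module is
        exported from ``torchcvnn.nn``:

        .. code-block:: python

            import torch
            import torch.nn as nn
            import torchcvnn.nn as c_nn

            head = nn.Sequential(c_nn.Mod(), nn.Linear(16, 4))  # real-valued from Mod onwards
            z = torch.randn(8, 16, dtype=torch.complex64)
            out = head(z)  # real tensor of shape (8, 4)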
237 """
239 def __init__(self):
240 super().__init__()
242 def forward(self, z: torch.Tensor):
243 """
244 Performs the forward pass.
246 Arguments:
247 z: the input tensor on which to apply the activation function
248 """
249 return torch.abs(z)


class modReLU(nn.Module):
    r"""
    Applies a ReLU with a parametric offset on the magnitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`

    where :math:`\theta` is the argument of :math:`z` and :math:`b` is a trainable offset.
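
    Example:

        A short sketch, assuming the activation is exported from ``torchcvnn.nn``;
        the offset ``b`` is initialized to ``0``, so at initialization the input is
        returned unchanged:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.modReLU()
            with torch.no_grad():
                act.b.fill_(-1.0)  # dead zone for magnitudes below 1
            z = torch.tensor([0.3 + 0.4j, 3.0 + 4.0j])  # magnitudes 0.5 and 5.0
            act(z)  # 0+0j, and 3+4j rescaled to magnitude 4 with the same phase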
257 """
259 def __init__(self):
260 super().__init__()
261 self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)
263 def forward(self, z: torch.Tensor):
264 """
265 Performs the forward pass.
267 Arguments:
268 z: the input tensor on which to apply the activation function
269 """
270 return nn.functional.relu(z.abs() + self.b) * torch.exp(1j * z.angle())


class Cardioid(nn.Module):
    r"""
    The cardioid activation function, as proposed by Virtue et al. (2019), is given by:

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    where :math:`\theta` is the argument of :math:`z`. For real-valued inputs, i.e.
    :math:`\theta \in \{0, \pi\}`, it reduces to the ReLU:

    :math:`\forall x \in \mathbb{R}, \; Cardioid(x) = ReLU(x)`
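
    Example:

        A quick numerical check of the ReLU reduction on real inputs, assuming the
        activation is exported from ``torchcvnn.nn``:

        .. code-block:: python

            import torch
            import torchcvnn.nn as c_nn

            act = c_nn.Cardioid()
            x = torch.tensor([2.0 + 0.0j, -2.0 + 0.0j, 0.0 + 2.0j])
            act(x)  # 2+0j, 0+0j, 0+1j : the purely imaginary input is halved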
282 """
284 def __init__(self):
285 super().__init__()
286 self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)
288 def forward(self, z: torch.Tensor):
289 """
290 Performs the forward pass.
292 Arguments:
293 z: the input tensor on which to apply the activation function
294 """
295 return 0.5 * (1 + torch.cos(z.angle())) * z


class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex-valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::

        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, K W^K_i, V W^V_i)`

    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}]) V

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        kdim: Total number of features for keys. Default `None` which uses `kdim=embed_dim`
        vdim: Total number of features for values. Default `None` which uses `vdim=embed_dim`
        batch_first: If `True`, then the input and output tensors are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)
    Example:

        .. code-block:: python

            import torchcvnn.nn as c_nn
            import torch

            nhead = 8
            seq_len = 10
            batch_size = 32
            num_features = 512

            multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
            src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
            attn_output, attn_output_weights = multihead_attn(src, src, src)
            # attn_output is (seq_len, batch_size, num_features)
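
        The same module can be used with batch-first tensors; a short sketch under the
        same assumptions as above:

        .. code-block:: python

            multihead_attn = c_nn.MultiheadAttention(
                embed_dim=num_features, num_heads=nhead, batch_first=True
            )
            src = torch.rand(batch_size, seq_len, num_features, dtype=torch.complex64)
            attn_output, attn_output_weights = multihead_attn(src, src, src)
            # attn_output is (batch_size, seq_len, num_features)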
349 """
351 def __init__(
352 self,
353 embed_dim: int,
354 num_heads: int,
355 dropout: float = 0.0,
356 bias: bool = True,
357 add_bias_kv=False,
358 add_zero_attn=False,
359 kdim: int = None,
360 vdim: int = None,
361 batch_first: bool = False,
362 device: torch.device = None,
363 dtype: torch.dtype = torch.complex64,
364 ):
365 factory_kwargs = {"device": device, "dtype": dtype}
366 super().__init__()
367 self.embed_dim = embed_dim
368 self.kdim = kdim if kdim is not None else embed_dim
369 self.vdim = vdim if vdim is not None else embed_dim
370 self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
372 self.num_heads = num_heads
373 self.dropout = dropout
374 self.batch_first = batch_first
375 self.head_dim = embed_dim // num_heads
376 assert (
377 self.head_dim * num_heads == self.embed_dim
378 ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs)
            )
            self.k_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs)
            )
            self.v_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        if add_bias_kv:
            self.bias_k = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            complex_xavier_uniform_(self.in_proj_weight)
        else:
            complex_xavier_uniform_(self.q_proj_weight)
            complex_xavier_uniform_(self.k_proj_weight)
            complex_xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)
            torch.nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            torch.nn.init.constant_(self.bias_k, 0.0)
        if self.bias_v is not None:
            torch.nn.init.constant_(self.bias_v, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Computes attention outputs using query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex-valued
        tensors. It keeps the same signature but does not yet support key_padding_mask
        and attn_mask.
        """

        is_batched = query.dim() == 3

        if key_padding_mask is not None:
            raise NotImplementedError("key_padding_mask is not supported yet")
            # key_padding_mask = F._canonical_mask(
            #     mask=key_padding_mask,
            #     mask_name="key_padding_mask",
            #     other_type=F._none_or_dtype(attn_mask),
            #     other_name="attn_mask",
            #     target_type=query.dtype,  # Adapted because q is complex
            # )
            # But
            # F._canonical_mask raises an exception
            # AssertionError: only bool and floating types of key_padding_mask are supported

        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not supported yet")
            # attn_mask = F._canonical_mask(
            #     mask=attn_mask,
            #     mask_name="attn_mask",
            #     other_type=None,
            #     other_name="",
            #     target_type=query.dtype,  # Adapted because q is complex
            #     check_other=False,
            # )

        if self.batch_first and is_batched:
            # These steps prevent multiple transpose on the same tensors
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights