Coverage for  / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / activation.py: 99%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 06:48 +0000

1# MIT License 

2 

3# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Callable 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29import torch.nn.functional as F 

30 

31# Local imports 

32from torchcvnn.nn import functional as c_F 

33from torchcvnn.nn.modules.normalization import RMSNorm, LayerNorm 

34from .initialization import complex_xavier_uniform_ 

35 

36 

class IndependentRealImag(nn.Module):
    """
    Wraps a real valued activation function so that it is applied separately
    to the real and to the imaginary component of a complex input.

    Arguments:
        fact: A nn.Module name of a real valued activation function
    """

    def __init__(self, fact: nn.Module):
        super().__init__()
        # Two independent instances so that parametrized activations
        # (e.g. PReLU) can learn distinct parameters per component.
        self.act_real = fact()
        self.act_imag = fact()

    def forward(self, z: torch.tensor) -> torch.tensor:
        """
        Performs the forward pass

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        real_part = self.act_real(z.real)
        imag_part = self.act_imag(z.imag)
        return real_part + 1j * imag_part

59 

60 

class CReLU(IndependentRealImag):
    """
    Applies a ReLU independently on both the real and imaginary parts

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both `\\Re[z]` and `\\Im[z]` are negative is
    projected to :math:`0`; in every other quadrant at least one of the real
    or imaginary parts is preserved.
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.ReLU)

74 

75 

class CPReLU(IndependentRealImag):
    """
    Applies a PReLU independently on both the real and imaginary parts

    :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`

    The real and imaginary branches each hold their own learnable PReLU slope.
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.PReLU)

85 

86 

class CELU(IndependentRealImag):
    """
    Applies a ELU independently on both the real and imaginary parts

    Not to confuse with `torch.nn.CELU`. For the complex equivalent of
    :external:py:class:`torch.nn.CELU`, see :class:`torchcvnn.nn.modules.activation.CCELU`

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.ELU)

99 

100 

class CCELU(IndependentRealImag):
    """
    Applies a CELU independently on both the real and imaginary parts

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.CELU)

110 

111 

class CGELU(IndependentRealImag):
    """
    Applies a GELU independently on both the real and imaginary parts

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.GELU)

121 

122 

class CSigmoid(IndependentRealImag):
    """
    Applies a Sigmoid independently on both the real and imaginary parts

    as used in Nitta Tohru. An extension of the back-propagation algorithm to complex numbers. Neural Networks, 10(9):1391–1415, November 1997.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real valued sigmoid is applied in the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(fact=nn.Sigmoid)

136 

137 

class CTanh(IndependentRealImag):
    """
    Applies a Tanh independently on both the real and imaginary parts

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real valued tanh is applied in the right hand side terms.
    """

    # Doc fix: the docstring previously claimed "the real valued sigmoid is
    # applied", copied from CSigmoid; this activation applies tanh.
    def __init__(self) -> None:
        super().__init__(nn.Tanh)

149 

150 

class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    Every quadrant except the one where both :math:`\Re[z]` and :math:`\Im[z]`
    are strictly positive is projected to :math:`0`. In other words, only one
    quadrant is preserved.

    (Doc fix: the docstring previously stated the opposite — that the quadrant
    where both parts are "non negative" is projected to 0 — which contradicts
    the formula and the implementation.)
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function

        Returns:
            z where both real and imaginary parts are strictly positive,
            0 elsewhere.
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        # Multiplying by the boolean mask zeroes every entry outside the
        # strictly-positive quadrant.
        return z * (pos_real & pos_img)

174 

175 

class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels all the complex plane in the circle of radius :math:`a`, where :math:`a` is
    trainable.
    """

    def __init__(self):
        super().__init__()
        # Trainable magnitude threshold, initialized at 1.
        # NOTE(review): the threshold only enters the forward pass through a
        # boolean comparison, so `a` receives no gradient from this module —
        # confirm whether training it is actually intended.
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([1.0]), requires_grad=True
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function

        Returns:
            z where :math:`|z| \geq a`, and 0 inside the circle of radius a.
        """
        # Bug fix: the mask was inverted (`z.abs() < self.a`), which kept the
        # inside of the circle and cancelled everything outside it — the
        # opposite of the documented behavior.
        mask = z.abs() >= self.a
        return z * mask

201 

202 

class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a.z & \mbox{otherwise} \end{cases}`

    """

    def __init__(self):
        super().__init__()
        # Trainable leak factor applied outside the strictly-positive quadrant.
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([0.2]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # True exactly where both components are strictly positive.
        in_quadrant = (z.real > 0) * (z.imag > 0)
        kept = z * in_quadrant
        leaked = self.a * (z * ~in_quadrant)
        return kept + leaked

227 

228 

class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input. It maps to :math:`\mathbb{R}`

    :math:`Mod(z) = |z|`

    This activation function allows to go from complex values to real
    values.

    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return z.abs()

251 

252 

class modReLU(nn.Module):
    r"""
    Applies a ReLU with parametric offset on the amplitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`
    """

    def __init__(self):
        super().__init__()
        # Trainable scalar offset b added to the magnitude before the ReLU.
        self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # Rectify the (offset) magnitude, then re-attach the original phase.
        magnitude = F.relu(z.abs() + self.b)
        phase = torch.exp(1j * z.angle())
        return magnitude * phase

272 

273 

class Cardioid(nn.Module):
    r"""
    The cardioid activation function as proposed by Virtue et al. (2019) is given by :

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    For real numbers, e.g. :math:`\theta \in \{0, \pi\}`, it reduces to the ReLU :

    :math:`\forall r \in \mathbb{R}, \theta \in \{0, \pi\}, Cardioid(r e^{j \theta}) = ReLU(r) e^{j \theta} = ReLU(r)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # Phase-dependent gain in [0, 1]: 1 along the positive real axis,
        # 0 along the negative real axis.
        theta = z.angle()
        gain = (1 + torch.cos(theta)) / 2
        return gain * z

296 

297 

class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::
        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, KW^K_i, VW^V_i)`


    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}])V

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split accross `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        norm_layer: Normalization layer applied per head to the queries and the keys. Default: `LayerNorm`
        bias: If `True`, adds bias to the input and output projections. Default: `True`
        batch_first: If `True`, then the input (query, key, value) and output tensors (attn_outputs) are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)
        device: Device on which the parameters are allocated. Default: `None`
        dtype: Dtype of the parameters. Default: `torch.complex64`

    (Doc fix: the previous docstring documented `kdim`/`vdim` arguments that do
    not exist in the signature, and omitted `norm_layer`, `bias`, `device` and
    `dtype` which do.)

    Example:

    .. code-block:: python

        import torchcvnn as c_nn
        import torch

        nhead = 8
        seq_len = 10
        batch_size = 32
        num_features = 512

        multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
        src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
        attn_output, attn_output_weights = multihead_attn(src, src, src)
        # attn_output is (seq_len, batch_size, num_features)

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = LayerNorm,
        bias: bool = True,
        batch_first: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Per-head normalization of queries and keys (QK-norm).
        self.q_norm = norm_layer(self.head_dim)
        self.k_norm = norm_layer(self.head_dim)

        # Packed projection weight for Q, K, V: (3*E, E).
        self.in_proj_weight = torch.nn.parameter.Parameter(
            torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
        )

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Initializes projection weights (complex Xavier) and zeroes the biases."""
        complex_xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)

        complex_xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            torch.nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        need_weights: bool = True,
        average_attn_weights: bool = True,
    ) -> torch.Tensor:
        """
        Computes attention outputs using query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

        Shape:
            Inputs:
                - query: :math:`(T, E)` or :math:`(T, B, E)` (``batch_first=False``) or :math:`(B, T, E)` (``batch_first=True``), where T is the target sequence length, B is the batch size, E is the embedding dimension
                - key: :math:`(S, E)` or :math:`(S, B, E)` (``batch_first=False``) or :math:`(B, S, E)` (``batch_first=True``), where S is the source sequence length, B is the batch size, E is the embedding dimension.
                - value: :math:`(S, E)` or :math:`(S, B, E)` (``batch_first=False``) or :math:`(B, S, E)` (``batch_first=True``), where S is the source sequence length, B is the batch size, E is the embedding dimension.

            Outputs:
                - attn_output: :math:`(T, E)` or :math:`(T, B, E)` (``batch_first=False``) or :math:`(B, T, E)` (``batch_first=True``), where T is the target sequence length, B is the batch size, E is the embedding dimension
                - attn_output_weights :math:`(T, S)` or :math:`(B, T, S)` Optional output, not available if need_weights=False
        """

        is_batched = query.dim() == 3

        if self.batch_first and is_batched:
            # In this case, query is (B, T, E), key is (B, S, E) and value is (B, S, E)

            # These steps prevent multiple transpose on the same tensors
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (
                    x.transpose(1, 0) for x in (query, key, value)
                )  # (T, B, E), (S, B, E), (S, B, E)

        attn_output, attn_output_weights = c_F.multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            q_norm=self.q_norm,
            k_norm=self.k_norm,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            dropout_p=self.dropout,
            out_proj=self.out_proj,
            training=self.training,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
        )
        # attn_output is (T, E) or (T, B, E)
        # attn_output_weights is (T, S) or (B, T, S) (already batch_first)
        if is_batched and self.batch_first:
            # Bug fix: this branch used to `return attn_output.transpose(1, 0)`
            # immediately, silently dropping attn_output_weights even when
            # need_weights=True — contradicting the documented outputs and
            # torch.nn.MultiheadAttention's contract. Transpose back to
            # (B, T, E) and fall through to the common return logic.
            attn_output = attn_output.transpose(1, 0)

        if need_weights:
            return attn_output, attn_output_weights
        return attn_output