Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/activation.py: 95%

138 statements  

coverage.py v7.9.2, created at 2025-07-05 06:39 +0000

# MIT License

# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Optional, Tuple

# External imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local imports
from torchcvnn.nn import functional as c_F
from .initialization import complex_xavier_uniform_


class IndependentRealImag(nn.Module):
    """
    Generic module to apply a real valued activation function independently
    on both the real and imaginary part

    Arguments:
        fact: the class of a real valued activation function (a `nn.Module` subclass, e.g. `nn.ReLU`)
    """

    def __init__(self, fact: nn.Module):
        super().__init__()
        self.act_real = fact()
        self.act_imag = fact()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return self.act_real(z.real) + self.act_imag(z.imag) * 1j

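
# A minimal usage sketch, not part of the original module: it shows how an arbitrary
# real valued activation (here nn.Softplus, an illustrative choice) can be lifted to
# complex tensors with IndependentRealImag. The helper name is hypothetical.
def _example_split_softplus() -> torch.Tensor:
    act = IndependentRealImag(nn.Softplus)  # Softplus on the real and imaginary parts
    z = torch.randn(2, 3, dtype=torch.complex64)
    return act(z)  # same shape as z, still complex valued
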

class CReLU(IndependentRealImag):
    """
    Applies a ReLU independently on both the real and imaginary parts

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both :math:`\\Re[z]` and :math:`\\Im[z]` are negative is projected to
    :math:`0`. Otherwise either the real and/or the imaginary part is preserved.
    """

    def __init__(self) -> None:
        super().__init__(nn.ReLU)

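
# A minimal usage sketch, not part of the original module; the sample values are
# illustrative. Only the component-wise negative parts are clipped, so a point with
# both parts negative maps to 0.
def _example_crelu() -> torch.Tensor:
    z = torch.tensor([1.0 - 2.0j, -3.0 + 4.0j, -1.0 - 1.0j])
    # Expected: 1 + 0j, 0 + 4j and 0
    return CReLU()(z)
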

class CPReLU(IndependentRealImag):
    """
    Applies a PReLU independently on both the real and imaginary parts

    :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.PReLU)


class CELU(IndependentRealImag):
    """
    Applies an ELU independently on both the real and imaginary parts

    Not to be confused with :external:py:class:`torch.nn.CELU`. For the complex equivalent of
    :external:py:class:`torch.nn.CELU`, see :class:`torchcvnn.nn.modules.activation.CCELU`.

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.ELU)


class CCELU(IndependentRealImag):
    """
    Applies a CELU independently on both the real and imaginary parts

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.CELU)


class CGELU(IndependentRealImag):
    """
    Applies a GELU independently on both the real and imaginary parts

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.GELU)


class CSigmoid(IndependentRealImag):
    """
    Applies a Sigmoid independently on both the real and imaginary parts

    as used in Nitta Tohru, "An extension of the back-propagation algorithm to complex
    numbers", Neural Networks, 10(9):1391–1415, November 1997.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real valued sigmoid is applied in the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Sigmoid)


class CTanh(IndependentRealImag):
    """
    Applies a Tanh independently on both the real and imaginary parts

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real valued tanh is applied in the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Tanh)


class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    All the quadrants where either :math:`\Re[z]` or :math:`\Im[z]` is non positive are
    projected to :math:`0`. In other words, only the first quadrant is preserved.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img

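
# A minimal sketch contrasting zReLU with CReLU, not part of the original module;
# the sample values are illustrative. zReLU keeps an entry only when both its real
# and imaginary parts are strictly positive.
def _example_zrelu() -> torch.Tensor:
    z = torch.tensor([1.0 + 2.0j, 1.0 - 2.0j, -1.0 + 2.0j])
    # Expected: only 1 + 2j survives, the two other entries are zeroed
    return zReLU()(z)
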

class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels all the complex plane in the circle of radius :math:`a`, where :math:`a` is
    trainable.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([1.0]), requires_grad=True
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        # Keep the values whose magnitude is at least a; everything inside the
        # circle of radius a is cancelled, as stated in the docstring.
        mask = z.abs() >= self.a
        return z * mask

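
# A minimal usage sketch, not part of the original module; the sample values are
# illustrative. With the default radius a = 1.0, entries inside the unit circle are
# cancelled and the others pass through unchanged.
def _example_zabsrelu() -> torch.Tensor:
    z = torch.tensor([0.3 + 0.4j, 3.0 - 4.0j])  # magnitudes 0.5 and 5.0
    # Expected: 0 and 3 - 4j
    return zAbsReLU()(z)
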

class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a \cdot z & \mbox{otherwise} \end{cases}`

    where the slope :math:`a` is trainable.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([0.2]), requires_grad=True
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img + self.a * (z * ~(pos_real * pos_img))

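
# A minimal usage sketch, not part of the original module; the sample values are
# illustrative. Outside the first quadrant the input is scaled by the trainable
# slope a (initially 0.2) instead of being zeroed.
def _example_zleakyrelu() -> torch.Tensor:
    z = torch.tensor([1.0 + 2.0j, -1.0 - 2.0j])
    # Expected: 1 + 2j unchanged, -1 - 2j scaled to -0.2 - 0.4j
    return zLeakyReLU()(z)
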

class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input. It maps to :math:`\mathbb{R}`

    :math:`Mod(z) = |z|`

    This activation function makes it possible to go from complex valued tensors to
    real valued ones.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return torch.abs(z)

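
# A minimal usage sketch, not part of the original module; the sample value is
# illustrative. Mod is typically the transition point from a complex valued trunk
# to a real valued head.
def _example_mod() -> torch.Tensor:
    z = torch.tensor([3.0 + 4.0j])
    return Mod()(z)  # tensor([5.]) with a real dtype
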

class modReLU(nn.Module):
    r"""
    Applies a ReLU with a parametric offset on the amplitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`

    where :math:`\theta` is the argument of :math:`z` and :math:`b` is a trainable bias.
    """

    def __init__(self):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return nn.functional.relu(z.abs() + self.b) * torch.exp(1j * z.angle())

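
# A minimal usage sketch, not part of the original module; the bias value and samples
# are illustrative. The magnitude is shifted and rectified while the phase of every
# entry is preserved.
def _example_modrelu() -> torch.Tensor:
    act = modReLU()
    with torch.no_grad():
        act.b.fill_(-1.0)  # shrink every magnitude by 1 before the rectification
    z = torch.tensor([3.0 + 4.0j, 0.3 + 0.4j])  # magnitudes 5.0 and 0.5
    # Expected: 2.4 + 3.2j (magnitude 4, same phase) and 0
    return act(z)
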

class Cardioid(nn.Module):
    r"""
    The cardioid activation function, as proposed by Virtue et al. (2019), is given by:

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    where :math:`\theta` is the argument of :math:`z`. For real valued inputs, i.e.
    :math:`\theta \in \{0, \pi\}`, it reduces to the ReLU:

    :math:`\forall r \in \mathbb{R}, Cardioid(r) = ReLU(r)`
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return 0.5 * (1 + torch.cos(z.angle())) * z

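
# A minimal check, not part of the original module; the sample values are illustrative.
# It shows that the cardioid reduces to the ReLU on the real axis and halves purely
# imaginary inputs.
def _example_cardioid() -> torch.Tensor:
    z = torch.tensor([2.0 + 0.0j, -2.0 + 0.0j, 0.0 + 2.0j])
    # Expected: 2 (theta = 0), 0 (theta = pi) and 1j (theta = pi / 2)
    return Cardioid()(z)
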

class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::

        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, KW^K_i, VW^V_i)`

    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}])V

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        kdim: Total number of features for keys. Default `None` which uses `kdim=embed_dim`
        vdim: Total number of features for values. Default `None` which uses `vdim=embed_dim`
        batch_first: If `True`, then the input and output tensors are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)

    Example:

        .. code-block:: python

            import torchcvnn.nn as c_nn
            import torch

            nhead = 8
            seq_len = 10
            batch_size = 32
            num_features = 512

            multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
            src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
            attn_output, attn_output_weights = multihead_attn(src, src, src)
            # attn_output is (seq_len, batch_size, num_features)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        batch_first: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs)
            )
            self.k_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs)
            )
            self.v_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        if add_bias_kv:
            self.bias_k = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            complex_xavier_uniform_(self.in_proj_weight)
        else:
            complex_xavier_uniform_(self.q_proj_weight)
            complex_xavier_uniform_(self.k_proj_weight)
            complex_xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)
            torch.nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            torch.nn.init.constant_(self.bias_k, 0.0)
        if self.bias_v is not None:
            torch.nn.init.constant_(self.bias_v, 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Computes attention outputs using query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex valued tensors. It keeps the same
        signature but does not yet support key_padding_mask and attn_mask.
        """

        is_batched = query.dim() == 3

        if key_padding_mask is not None:
            raise NotImplementedError("key_padding_mask is not supported yet")
            # key_padding_mask = F._canonical_mask(
            #     mask=key_padding_mask,
            #     mask_name="key_padding_mask",
            #     other_type=F._none_or_dtype(attn_mask),
            #     other_name="attn_mask",
            #     target_type=query.dtype,  # Adapted because q is complex
            # )
            # But
            # F._canonical_mask raises an exception
            # AssertionError: only bool and floating types of key_padding_mask are supported

        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not supported yet")
            # attn_mask = F._canonical_mask(
            #     mask=attn_mask,
            #     mask_name="attn_mask",
            #     other_type=None,
            #     other_name="",
            #     target_type=query.dtype,  # Adapted because q is complex
            #     check_other=False,
            # )

        if self.batch_first and is_batched:
            # These steps prevent multiple transpose on the same tensors
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
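
# A minimal self-attention sketch with batch-first tensors, not part of the original
# module; all sizes are illustrative. It complements the sequence-first example from
# the class docstring above.
def _example_batch_first_attention() -> torch.Tensor:
    mha = MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
    x = torch.rand(8, 10, 64, dtype=torch.complex64)  # (batch, seq, feature)
    attn_output, attn_weights = mha(x, x, x)
    return attn_output  # (batch, seq, feature), complex valued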