Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/activation.py: 95%

139 statements  

coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

# MIT License

# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Optional, Tuple

# External imports
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local imports
from torchcvnn.nn import functional as c_F
from .initialization import complex_xavier_uniform_


class IndependentRealImag(nn.Module):
    """
    Generic module to apply a real valued activation function independently
    on both the real and imaginary parts

    Arguments:
        fact: a real valued activation function class (a nn.Module subclass such as nn.ReLU)
    """

    def __init__(self, fact: nn.Module):
        super().__init__()
        self.act_real = fact()
        self.act_imag = fact()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return self.act_real(z.real) + self.act_imag(z.imag) * 1j

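# Illustrative sketch, not part of the original module: any real valued activation
# class can be wrapped by IndependentRealImag; torch.nn.SiLU is used here as an
# example. The helper name below is hypothetical.
def _demo_split_activation() -> torch.Tensor:
    split_silu = IndependentRealImag(nn.SiLU)
    z = torch.tensor([1.0 - 2.0j, -0.5 + 0.3j], dtype=torch.complex64)
    # the real and imaginary parts go through two independent SiLU instances
    return split_silu(z)
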

class CReLU(IndependentRealImag):
    """
    Applies a ReLU independently on both the real and imaginary parts

    :math:`CReLU(z) = ReLU(\\Re[z]) + ReLU(\\Im[z])j`

    Only the quadrant where both :math:`\\Re[z]` and :math:`\\Im[z]` are negative is projected to
    :math:`0`. Otherwise, the real part, the imaginary part, or both are preserved.
    """

    def __init__(self) -> None:
        super().__init__(nn.ReLU)

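# Illustrative sketch, not part of the original module: applying CReLU to one point
# per quadrant of the complex plane. The helper name below is hypothetical.
def _demo_crelu() -> torch.Tensor:
    z = torch.tensor([1 + 1j, -1 + 1j, 1 - 1j, -1 - 1j], dtype=torch.complex64)
    # expected: [1+1j, 0+1j, 1+0j, 0+0j] - only the all-negative quadrant maps to 0
    return CReLU()(z)
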

class CPReLU(IndependentRealImag):
    """
    Applies a PReLU independently on both the real and imaginary parts

    :math:`CPReLU(z) = PReLU(\\Re[z]) + PReLU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.PReLU)


class CELU(IndependentRealImag):
    """
    Applies an ELU independently on both the real and imaginary parts

    Not to be confused with `torch.nn.CELU`. For the complex equivalent of
    :external:py:class:`torch.nn.CELU`, see :class:`torchcvnn.nn.modules.activation.CCELU`

    :math:`CELU(z) = ELU(\\Re[z]) + ELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.ELU)


class CCELU(IndependentRealImag):
    """
    Applies a CELU independently on both the real and imaginary parts

    :math:`CCELU(z) = CELU(\\Re[z]) + CELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.CELU)


class CGELU(IndependentRealImag):
    """
    Applies a GELU independently on both the real and imaginary parts

    :math:`CGELU(z) = GELU(\\Re[z]) + GELU(\\Im[z])j`
    """

    def __init__(self) -> None:
        super().__init__(nn.GELU)


class CSigmoid(IndependentRealImag):
    """
    Applies a Sigmoid independently on both the real and imaginary parts

    as used in Nitta, Tohru. An extension of the back-propagation algorithm to complex numbers. Neural Networks, 10(9):1391–1415, November 1997.

    :math:`CSigmoid(z) = Sigmoid(\\Re[z]) + Sigmoid(\\Im[z])j`

    where the real valued sigmoid is applied on the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Sigmoid)


class CTanh(IndependentRealImag):
    """
    Applies a Tanh independently on both the real and imaginary parts

    :math:`CTanh(z) = \\tanh(\\Re[z]) + \\tanh(\\Im[z])j`

    where the real valued tanh is applied on the right hand side terms.
    """

    def __init__(self) -> None:
        super().__init__(nn.Tanh)


class zReLU(nn.Module):
    r"""
    Applies a zReLU

    :math:`zReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ 0 & \mbox{otherwise} \end{cases}`

    All inputs where :math:`\Re[z]` or :math:`\Im[z]` is non positive are projected
    to :math:`0`. In other words, only one quadrant is preserved.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img

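# Illustrative sketch, not part of the original module: only the first quadrant
# survives zReLU. The helper name below is hypothetical.
def _demo_zrelu() -> torch.Tensor:
    z = torch.tensor([1 + 1j, -1 + 1j, 1 - 1j, -1 - 1j], dtype=torch.complex64)
    # expected: [1+1j, 0, 0, 0] - only the point with both parts positive is kept
    return zReLU()(z)
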

class zAbsReLU(nn.Module):
    r"""
    Applies a zAbsReLU

    :math:`zAbsReLU(z) = \begin{cases} z & \mbox{if } |z| \geq a\\ 0 & \mbox{otherwise} \end{cases}`

    This cancels every input lying inside the disc of radius :math:`a`, where :math:`a` is
    trainable.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([1.0]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
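        # Note: the mask below keeps points with |z| < a and zeroes the rest, whereas
        # the docstring formula above keeps z when |z| >= a; the two orientations
        # appear to disagree and may be worth checking against the intended behaviour.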

        mask = z.abs() < self.a
        return z * mask


class zLeakyReLU(nn.Module):
    r"""
    Applies a zLeakyReLU

    :math:`zLeakyReLU(z) = \begin{cases} z & \mbox{if } \Re[z] > 0 \mbox{ and } \Im[z] > 0\\ a \cdot z & \mbox{otherwise} \end{cases}`

    where :math:`a` is a trainable scaling factor, initialized to :math:`0.2`.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.parameter.Parameter(
            data=torch.Tensor([0.2]), requires_grad=True
        )

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        pos_real = z.real > 0
        pos_img = z.imag > 0
        return z * pos_real * pos_img + self.a * (z * ~(pos_real * pos_img))

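# Illustrative sketch, not part of the original module: outside the first quadrant
# the input is only attenuated by a. The helper name below is hypothetical.
def _demo_zleakyrelu() -> torch.Tensor:
    z = torch.tensor([1 + 1j, -1 + 1j], dtype=torch.complex64)
    # expected with the default a = 0.2: [1+1j, -0.2+0.2j]
    return zLeakyReLU()(z)
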

class Mod(nn.Module):
    r"""
    Extracts the magnitude of the complex input. It maps to :math:`\mathbb{R}`

    :math:`Mod(z) = |z|`

    This activation function allows going from complex valued tensors to real
    valued tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return torch.abs(z)

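# Illustrative sketch, not part of the original module: Mod is a convenient head to
# bridge a complex valued network with a real valued loss. The helper name below is
# hypothetical.
def _demo_mod() -> torch.Tensor:
    z = torch.randn(4, 3, dtype=torch.complex64)   # complex scores, 4 samples, 3 classes
    real_scores = Mod()(z)                         # real valued magnitudes, shape (4, 3)
    targets = torch.randint(0, 3, (4,))
    return nn.functional.cross_entropy(real_scores, targets)
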

class modReLU(nn.Module):
    r"""
    Applies a ReLU with a parametric offset on the amplitude, keeping the phase unchanged.

    :math:`modReLU(z) = ReLU(|z| + b) e^{j \theta}`

    where :math:`\theta` is the phase of :math:`z` and :math:`b` is a trainable offset.
    """

    def __init__(self):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return nn.functional.relu(z.abs() + self.b) * torch.exp(1j * z.angle())

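# Illustrative sketch, not part of the original module: with b = -1, magnitudes below
# 1 are zeroed while the phase of the surviving entries is preserved. The helper name
# below is hypothetical.
def _demo_modrelu() -> torch.Tensor:
    act = modReLU()
    with torch.no_grad():
        act.b.fill_(-1.0)
    z = torch.polar(torch.tensor([0.5, 2.0]), torch.tensor([0.3, 1.0]))
    # expected magnitudes: [relu(0.5 - 1), relu(2.0 - 1)] = [0, 1], phases unchanged
    return act(z)
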

class Cardioid(nn.Module):
    r"""
    The cardioid activation function, as proposed by Virtue et al. (2019), is given by:

    :math:`Cardioid(z) = \frac{1+\cos(\theta)}{2} z`

    where :math:`\theta` is the phase of :math:`z`. For inputs on the real axis,
    i.e. :math:`\theta \in \{0, \pi\}`, it reduces to the ReLU:

    :math:`\forall x \in \mathbb{R}, Cardioid(x) = ReLU(x)`
    """

    def __init__(self):
        super().__init__()
        self.b = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float), True)

    def forward(self, z: torch.Tensor):
        """
        Performs the forward pass.

        Arguments:
            z: the input tensor on which to apply the activation function
        """
        return 0.5 * (1 + torch.cos(z.angle())) * z

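# Illustrative sketch, not part of the original module: on the real axis the cardioid
# behaves like a ReLU, while purely imaginary inputs are halved. The helper name below
# is hypothetical.
def _demo_cardioid() -> torch.Tensor:
    z = torch.tensor([2.0 + 0.0j, -2.0 + 0.0j, 0.0 + 2.0j], dtype=torch.complex64)
    # expected (up to float rounding): [2+0j, 0+0j, 0+1j]
    return Cardioid()(z)
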

class MultiheadAttention(nn.Module):
    r"""
    This class is adapted from torch.nn.MultiheadAttention to support complex valued tensors.

    Allows the model to jointly attend to information from different
    representation subspaces as described in the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

    .. math::

        \mbox{MultiHead}(Q, K, V) = [head_1, \dots, head_h] W^O

    where :math:`head_i = \mbox{Attention}(Q W^Q_i, K W^K_i, V W^V_i)`

    This implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Attention is defined as follows:

    .. math::

        \mbox{Attention}(Q, K, V) = \sigma(\Re[\frac{Q K^H}{\sqrt{d_k}}])V

    Arguments:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel heads. Note that `embed_dim` will be split across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`)
        dropout: Dropout probability on `attn_output_weights`. Default: `0.0`
        kdim: Total number of features for keys. Default `None` which uses `kdim=embed_dim`
        vdim: Total number of features for values. Default `None` which uses `vdim=embed_dim`
        batch_first: If `True`, then the input and output tensors are provided as (batch, seq, feature). Default `False` with tensors as (seq, batch, feature)

    Example:

    .. code-block:: python

        import torchcvnn.nn as c_nn
        import torch

        nhead = 8
        seq_len = 10
        batch_size = 32
        num_features = 512

        multihead_attn = c_nn.MultiheadAttention(embed_dim=num_features, num_heads=nhead)
        src = torch.rand(seq_len, batch_size, num_features, dtype=torch.complex64)
        attn_output, attn_output_weights = multihead_attn(src, src, src)
        # attn_output is (seq_len, batch_size, num_features)
    """


    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim: int = None,
        vdim: int = None,
        batch_first: bool = False,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, embed_dim), **factory_kwargs)
            )
            self.k_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.kdim), **factory_kwargs)
            )
            self.v_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((embed_dim, self.vdim), **factory_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = torch.nn.parameter.Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)

        self.out_proj = torch.nn.Linear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        if add_bias_kv:
            self.bias_k = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
            self.bias_v = torch.nn.parameter.Parameter(
                torch.empty((1, 1, embed_dim), **factory_kwargs)
            )
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
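        # Note: when bias is True, the branch below re-creates in_proj_bias even though
        # it was already allocated above; the duplicate allocation appears redundant
        # since _reset_parameters() re-initializes the bias right after.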

        if bias:
            self.in_proj_bias = torch.nn.parameter.Parameter(
                torch.empty(3 * embed_dim, **factory_kwargs)
            )

        self._reset_parameters()


    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            complex_xavier_uniform_(self.in_proj_weight)
        else:
            complex_xavier_uniform_(self.q_proj_weight)
            complex_xavier_uniform_(self.k_proj_weight)
            complex_xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            torch.nn.init.constant_(self.in_proj_bias, 0.0)
            torch.nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            torch.nn.init.constant_(self.bias_k, 0.0)
        if self.bias_v is not None:
            torch.nn.init.constant_(self.bias_v, 0.0)


    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Computes attention outputs using query, key and value embeddings.

        This function is adapted from torch.nn.MultiheadAttention to support complex valued tensors. It keeps the same
        signature but does not yet support key_padding_mask and attn_mask.
        """

        is_batched = query.dim() == 3

        if key_padding_mask is not None:
            raise NotImplementedError("key_padding_mask is not supported yet")
            # key_padding_mask = F._canonical_mask(
            #     mask=key_padding_mask,
            #     mask_name="key_padding_mask",
            #     other_type=F._none_or_dtype(attn_mask),
            #     other_name="attn_mask",
            #     target_type=query.dtype,  # Adapted because q is complex
            # )
            # But
            # F._canonical_mask raises an exception
            # AssertionError: only bool and floating types of key_padding_mask are supported

        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not supported yet")
            # attn_mask = F._canonical_mask(
            #     mask=attn_mask,
            #     mask_name="attn_mask",
            #     other_type=None,
            #     other_name="",
            #     target_type=query.dtype,  # Adapted because q is complex
            #     check_other=False,
            # )

        if self.batch_first and is_batched:
            # These steps prevent multiple transpose on the same tensors
            # for example when using self-attention
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = c_F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
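

# Illustrative sketch, not part of the original module: the attention rule documented
# above, written out directly as a softmax over the real part of Q K^H / sqrt(d_k).
# This is a standalone formula check, not a re-implementation of
# c_F.multi_head_attention_forward; the helper name below is hypothetical.
def _demo_complex_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    d_k = q.shape[-1]
    # real valued attention weights from the real part of the complex dot products
    scores = (q @ k.conj().transpose(-2, -1)).real / d_k**0.5
    weights = torch.softmax(scores, dim=-1)
    # cast the weights back to the complex dtype before mixing the values
    return weights.to(v.dtype) @ v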