Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/transformer.py: 95%

118 statements  

coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

# MIT License

# Copyright (c) 2024 Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
from typing import Union, Callable, Optional

# External imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.transformer import (
    _get_clones,
    _get_seq_len,
    _detect_is_causal_mask,
    TransformerEncoder,
    TransformerDecoder,
)

# Local imports
from .activation import CReLU, MultiheadAttention
from .dropout import Dropout
from .normalization import LayerNorm
from .initialization import complex_xavier_uniform_


class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This class is adapted from PyTorch :py:class:`torch.nn.TransformerEncoderLayer`.

    This standard encoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8)
            src = torch.rand(10, 32, 512, dtype=torch.complex64)
            out = encoder_layer(src)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
            src = torch.rand(32, 10, 512, dtype=torch.complex64)
            out = encoder_layer(src)

    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
        attn_module=MultiheadAttention,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = attn_module(
            d_model,
            nhead,
            dropout=dropout,
            bias=bias,
            batch_first=batch_first,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        complex_xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.constant_(self.linear1.bias, 0)
        complex_xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.constant_(self.linear2.bias, 0)

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = CReLU()

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
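        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            is_causal: if ``True``, hints that ``src_mask`` is the causal mask.
                Default: ``False``.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """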

        x = src
        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x), src_mask, src_key_padding_mask, is_causal
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            x = x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal)
            x = self.norm1(x)
            x = x + self._ff_block(x)
            x = self.norm2(x)

        return x

    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool,
    ) -> torch.Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            is_causal=is_causal,
        )[0]
        x = self.dropout1(x)
        return x

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerDecoderLayer(nn.Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    Adapted from PyTorch :py:class:`torch.nn.TransformerDecoderLayer`.

    This standard decoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8)
            memory = torch.rand(10, 32, 512, dtype=torch.complex64)
            tgt = torch.rand(20, 32, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
            memory = torch.rand(32, 10, 512, dtype=torch.complex64)
            tgt = torch.rand(32, 20, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)
    """

    __constants__ = ["norm_first"]

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU and dtype=torch.complex64
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        complex_xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.constant_(self.linear1.bias, 0)
        complex_xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.constant_(self.linear2.bias, 0)

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU
    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = CReLU()
        super().__setstate__(state)

    # Same from Pytorch TransformerDecoderLayer
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> torch.Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
            tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
                Default: ``False``.
                Warning:
                ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
                the causal mask. Providing incorrect hints can result in
                incorrect execution, including forward and backward
                compatibility issues.
            memory_is_causal: If specified, applies a causal mask as
                ``memory_mask``.
                Default: ``False``.
                Warning:
                ``memory_is_causal`` provides a hint that
                ``memory_mask`` is the causal mask. Providing incorrect
                hints can result in incorrect execution, including
                forward and backward compatibility issues.

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
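        A minimal sketch of a causal call (``decoder_layer`` as in the class example
        above; this assumes the complex-valued attention accepts a boolean
        ``attn_mask``, as :py:class:`torch.nn.MultiheadAttention` does):

        .. code-block:: python

            tgt_len = tgt.size(0)  # use tgt.size(1) when batch_first=True
            # True above the diagonal marks future positions as masked
            tgt_mask = torch.triu(
                torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1
            )
            out = decoder_layer(tgt, memory, tgt_mask=tgt_mask, tgt_is_causal=True)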

        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
            )
            x = x + self._mha_block(
                self.norm2(x),
                memory,
                memory_mask,
                memory_key_padding_mask,
                memory_is_causal,
            )
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory, memory_mask, memory_key_padding_mask, memory_is_causal
                )
            )
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    # Same from Pytorch TransformerDecoderLayer
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
    ) -> torch.Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # Same from Pytorch TransformerDecoderLayer
    def _mha_block(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        is_causal: bool = False,
    ) -> torch.Tensor:
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    # Same from Pytorch TransformerDecoderLayer
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class Transformer(nn.Module):
    r"""A transformer model.

    Adapted from :py:class:`torch.nn.Transformer`.

    Users can modify the attributes as needed. The architecture
    is based on the paper **Attention Is All You Need**. Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    The :py:class:`MultiheadAttention` implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer. Default: :py:class:`CReLU`.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            transformer_model = c_nn.Transformer(nhead=16, num_encoder_layers=12)
            src = torch.rand((10, 32, 512), dtype=torch.complex64)
            tgt = torch.rand((20, 32, 512), dtype=torch.complex64)
            out = transformer_model(src, tgt)

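        Alternatively, when ``batch_first`` is ``True`` (a sketch assuming the same
        call pattern with batch-leading shapes):

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            transformer_model = c_nn.Transformer(nhead=16, num_encoder_layers=12, batch_first=True)
            src = torch.rand((32, 10, 512), dtype=torch.complex64)
            tgt = torch.rand((32, 20, 512), dtype=torch.complex64)
            out = transformer_model(src, tgt)
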

    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            **factory_kwargs,
        )
        encoder_norm = LayerNorm(
            d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
        )
        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        decoder_layer = TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            **factory_kwargs,
        )
        decoder_norm = LayerNorm(
            d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
        )
        self.decoder = TransformerDecoder(
            decoder_layer, num_decoder_layers, decoder_norm
        )

        self.d_model = d_model
        self.nhead = nhead

        self.batch_first = batch_first

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
        src_is_causal: Optional[bool] = None,
        tgt_is_causal: Optional[bool] = None,
        memory_is_causal: bool = False,
    ) -> torch.Tensor:
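        r"""Take in and process masked source/target sequences.

        Runs the encoder on ``src`` to produce ``memory``, then runs the decoder on
        ``tgt`` and ``memory``; the masks and ``*_is_causal`` hints are forwarded to
        the corresponding sub-modules. For the expected shapes, see the docs in
        :class:`~torch.nn.Transformer`.
        """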

        memory = self.encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
            is_causal=src_is_causal,
        )
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_is_causal=tgt_is_causal,
            memory_is_causal=memory_is_causal,
        )
        return output