Coverage for  / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / transformer.py: 96%

146 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-14 06:48 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Union, Callable, Optional 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29import torch.nn.functional as F 

30from torch.nn.modules.transformer import ( 

31 _get_clones 

32) 

33 

34# Local imports 

35from .activation import CReLU, MultiheadAttention 

36from .dropout import Dropout 

37from .normalization import LayerNorm 

38from .initialization import complex_xavier_uniform_ 

39 

40 

class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This class is adapted from pytorch :py:class:`torch.nn.TransformerEncoderLayer`

    This standard encoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: device on which the parameters are allocated (optional).
        dtype: dtype of the parameters (default=torch.complex64).
        attn_module: class used to build the self-attention module.
            Default: :py:class:`MultiheadAttention`.

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8)
            src = torch.rand(10, 32, 512, dtype=torch.complex64)
            out = encoder_layer(src)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
            src = torch.rand(32, 10, 512, dtype=torch.complex64)
            out = encoder_layer(src)

    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
        attn_module=MultiheadAttention,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.self_attn = attn_module(
            d_model,
            nhead,
            dropout=dropout,
            bias=bias,
            batch_first=batch_first,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        # Complex-valued Xavier init for the feedforward weights; zero biases.
        complex_xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.constant_(self.linear1.bias, 0)
        complex_xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.constant_(self.linear2.bias, 0)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Backward compatibility with checkpoints saved before the
        # ``activation`` attribute existed.
        if not hasattr(self, "activation"):
            self.activation = CReLU()

    def forward(
        self,
        src: torch.Tensor,
    ) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
        """
        x = src
        if self.norm_first:
            # Pre-norm: normalize, transform, then residual-add.
            x = x + self._sa_block(
                self.norm1(x)
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm: residual-add, then normalize.
            x = x + self._sa_block(x)
            x = self.norm1(x)
            x = x + self._ff_block(x)
            x = self.norm2(x)

        return x

    # self-attention block
    def _sa_block(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # BUGFIX: the attention module returns an ``(attn_output, attn_weights)``
        # tuple (PyTorch MultiheadAttention convention); keep only the output
        # tensor, consistent with TransformerDecoderLayer._sa_block. The
        # previous code forwarded the whole tuple into dropout1.
        x = self.self_attn(
            x,
            x,
            x,
            need_weights=False,
        )[0]
        x = self.dropout1(x)
        return x

    # feed forward block: linear -> activation -> dropout -> linear -> dropout
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

181 

182 

183class TransformerEncoder(nn.Module): 

184 r"""TransformerEncoder is a stack of N encoder layers. 

185 

186 This class is adapted from pytorch :py:class:`torch.nn.TransformerEncoder` 

187 

188 This TransformerEncoder layer implements the original architecture described 

189 in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The 

190 intent of this layer is as a reference implementation for foundational understanding 

191 and thus it contains only limited features relative to newer Transformer architectures. 

192 Given the fast pace of innovation in transformer-like architectures, we recommend 

193 exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_ 

194 to build efficient layers from building blocks in core or using higher 

195 level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_. 

196 

197 .. warning:: 

198 All layers in the TransformerEncoder are initialized with the same parameters. 

199 It is recommended to manually initialize the layers after creating the TransformerEncoder instance. 

200 

201 Args: 

202 encoder_layer: an instance of the TransformerEncoderLayer() class (required). 

203 num_layers: the number of sub-encoder-layers in the encoder (required). 

204 norm: the layer normalization component (optional). 

205 enable_nested_tensor: if True, input will automatically convert to nested tensor 

206 (and convert back on output). This will improve the overall performance of 

207 TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 

208 

209 Examples: 

210 >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 

211 >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 

212 >>> src = torch.rand(10, 32, 512) 

213 >>> out = transformer_encoder(src) 

214 """ 

215 

216 __constants__ = ["norm"] 

217 

218 def __init__( 

219 self, 

220 encoder_layer: "TransformerEncoderLayer", 

221 num_layers: int, 

222 norm: nn.Module | None = None, 

223 ) -> None: 

224 super().__init__() 

225 self.layers = _get_clones(encoder_layer, num_layers) 

226 self.num_layers = num_layers 

227 self.norm = norm 

228 

229 def forward( 

230 self, 

231 src: torch.Tensor, 

232 ) -> torch.Tensor: 

233 r"""Pass the input through the encoder layers in turn. 

234 

235 Args: 

236 src: the sequence to the encoder (required). 

237 

238 Shape: 

239 see the docs in :class:`~torch.nn.Transformer`. 

240 """ 

241 output = src 

242 

243 for mod in self.layers: 

244 output = mod( 

245 output, 

246 ) 

247 

248 if self.norm is not None: 

249 output = self.norm(output) 

250 

251 return output 

252 

253 

254 

class TransformerDecoderLayer(nn.Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    Adapted from Pytorch :py:class:`torch.nn.TransformerDecoderLayer`.

    This standard decoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: device on which the parameters are allocated (optional).
        dtype: dtype of the parameters (default=torch.complex64).

    Examples::

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8)
            memory = torch.rand(10, 32, 512, dtype=torch.complex64)
            tgt = torch.rand(20, 32, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
            memory = torch.rand(32, 10, 512, dtype=torch.complex64)
            tgt = torch.rand(32, 20, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)
    """

    __constants__ = ["norm_first"]

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU and dtype=torch.complex64
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Self-attention over the target sequence.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Cross-attention from the target onto the encoder memory.
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Position-wise feedforward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        # Complex-valued Xavier init for the feedforward weights; zero biases.
        for linear in (self.linear1, self.linear2):
            complex_xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.constant_(linear.bias, 0)

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU
    def __setstate__(self, state):
        # Backward compatibility with checkpoints saved before the
        # ``activation`` attribute existed.
        if "activation" not in state:
            state["activation"] = CReLU()
        super().__setstate__(state)

    # Same from Pytorch TransformerDecoderLayer
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        r"""Pass the inputs through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # Pre-norm vs post-norm placement:
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x))
            x = x + self._mha_block(self.norm2(x), memory)
            x = x + self._ff_block(self.norm3(x))
            return x

        x = self.norm1(x + self._sa_block(x))
        x = self.norm2(x + self._mha_block(x, memory))
        x = self.norm3(x + self._ff_block(x))
        return x

    # self-attention block
    # Same from Pytorch TransformerDecoderLayer
    def _sa_block(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # [0] keeps only attn_output from the (output, weights) tuple.
        attended = self.self_attn(
            x,
            x,
            x,
            need_weights=False,
        )[0]
        return self.dropout1(attended)

    # multihead attention block
    # Same from Pytorch TransformerDecoderLayer
    def _mha_block(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
    ) -> torch.Tensor:
        # Query is the target stream, keys/values come from encoder memory.
        attended = self.multihead_attn(
            x,
            mem,
            mem,
            need_weights=False,
        )[0]
        return self.dropout2(attended)

    # feed forward block
    # Same from Pytorch TransformerDecoderLayer
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout3(self.linear2(hidden))

449 

450class TransformerDecoder(nn.Module): 

451 r"""TransformerDecoder is a stack of N decoder layers. 

452 

453 Adapted from :py:class:`torch.nn.TransformerDecodder`. 

454 

455 This TransformerDecoder layer implements the original architecture described 

456 in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The 

457 intent of this layer is as a reference implementation for foundational understanding 

458 and thus it contains only limited features relative to newer Transformer architectures. 

459 Given the fast pace of innovation in transformer-like architectures, we recommend 

460 exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_ 

461 to build efficient layers from building blocks in core or using higher 

462 level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_. 

463 

464 .. warning:: 

465 All layers in the TransformerDecoder are initialized with the same parameters. 

466 It is recommended to manually initialize the layers after creating the TransformerDecoder instance. 

467 

468 Args: 

469 decoder_layer: an instance of the TransformerDecoderLayer() class (required). 

470 num_layers: the number of sub-decoder-layers in the decoder (required). 

471 norm: the layer normalization component (optional). 

472 

473 Examples: 

474 >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 

475 >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 

476 >>> memory = torch.rand(10, 32, 512) 

477 >>> tgt = torch.rand(20, 32, 512) 

478 >>> out = transformer_decoder(tgt, memory) 

479 """ 

480 

481 __constants__ = ["norm"] 

482 

483 def __init__( 

484 self, 

485 decoder_layer: "TransformerDecoderLayer", 

486 num_layers: int, 

487 norm: nn.Module | None = None, 

488 ) -> None: 

489 super().__init__() 

490 self.layers = _get_clones(decoder_layer, num_layers) 

491 self.num_layers = num_layers 

492 self.norm = norm 

493 

494 def forward( 

495 self, 

496 tgt: torch.Tensor, 

497 memory: torch.Tensor, 

498 ) -> torch.Tensor: 

499 r"""Pass the inputs (and mask) through the decoder layer in turn. 

500 

501 Args: 

502 tgt: the sequence to the decoder (required). 

503 memory: the sequence from the last layer of the encoder (required). 

504 

505 Shape: 

506 see the docs in :class:`~torch.nn.Transformer`. 

507 """ 

508 output = tgt 

509 

510 for mod in self.layers: 

511 output = mod( 

512 output, 

513 memory 

514 ) 

515 

516 if self.norm is not None: 

517 output = self.norm(output) 

518 

519 return output 

520 

class Transformer(nn.Module):
    r"""A transformer model.

    Adapted from :py:class:`torch.nn.Transformer`.

    User is able to modify the attributes as needed. The architecture
    is based on the paper **Attention Is All You Need**. Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    The :py:class:`MultiheadAttention` implementation is based on the paper **Building blocks for a complex-valued
    transformer architecture**. Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on Acoustics, Speech,
    and Signal Processing (ICASSP).

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer. Default: :py:class:`CReLU`.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: device on which the parameters are allocated (optional).
        dtype: dtype of the parameters (default=torch.complex64).

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            transformer_model = c_nn.Transformer(nhead=16, num_encoder_layers=12)
            src = torch.rand((10, 32, 512), dtype=torch.complex64)
            tgt = torch.rand((20, 32, 512), dtype=torch.complex64)
            out = transformer_model(src, tgt)

    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Encoder stack: clones of a single prototype layer plus a final norm.
        enc_layer = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            **factory_kwargs,
        )
        enc_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.encoder = TransformerEncoder(enc_layer, num_encoder_layers, enc_norm)

        # Decoder stack, built the same way.
        dec_layer = TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            **factory_kwargs,
        )
        dec_norm = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.decoder = TransformerDecoder(dec_layer, num_decoder_layers, dec_norm)

        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
    ) -> torch.Tensor:
        r"""Encode ``src``, then decode ``tgt`` against the encoded memory.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
        """
        memory = self.encoder(src)
        return self.decoder(tgt, memory)