Coverage for / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / transformer.py: 96%
146 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-14 06:48 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Union, Callable, Optional
26# External imports
27import torch
28import torch.nn as nn
29import torch.nn.functional as F
30from torch.nn.modules.transformer import (
31 _get_clones
32)
34# Local imports
35from .activation import CReLU, MultiheadAttention
36from .dropout import Dropout
37from .normalization import LayerNorm
38from .initialization import complex_xavier_uniform_
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This class is adapted from pytorch :py:class:`torch.nn.TransformerEncoderLayer`

    This standard encoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    If you are implementing a custom layer, you may derive it either from
    the Module or TransformerEncoderLayer class.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: the device on which the parameters are allocated (default=None).
        dtype: the dtype of the parameters (default=torch.complex64).
        attn_module: the attention class instantiated for self-attention.
            Default: :py:class:`MultiheadAttention`.

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8)
            src = torch.rand(10, 32, 512, dtype=torch.complex64)
            out = encoder_layer(src)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            encoder_layer = c_nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
            src = torch.rand(32, 10, 512, dtype=torch.complex64)
            out = encoder_layer(src)

    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
        attn_module=MultiheadAttention,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Self-attention sub-block; the attention class is injectable so the
        # layer can be reused with a different attention implementation.
        self.self_attn = attn_module(
            d_model,
            nhead,
            dropout=dropout,
            bias=bias,
            batch_first=batch_first,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # ``activation`` is a class (e.g. CReLU); instantiate it here.
        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        # Complex-valued Xavier-uniform init for the feedforward weights,
        # zero-init for their biases (when biases are enabled).
        complex_xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.constant_(self.linear1.bias, 0)
        complex_xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.constant_(self.linear2.bias, 0)

    def __setstate__(self, state):
        # Backward compatibility: older pickles may lack the activation module.
        super().__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = CReLU()

    def forward(
        self,
        src: torch.Tensor,
    ) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
        """
        x = src
        if self.norm_first:
            # Pre-norm: normalize before each sub-block, residual add outside.
            x = x + self._sa_block(
                self.norm1(x)
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm (original Transformer): residual add first, then normalize.
            x = x + self._sa_block(x)
            x = self.norm1(x)
            x = x + self._ff_block(x)
            x = self.norm2(x)

        return x

    # Self-attention block: query = key = value = x.
    def _sa_block(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE(review): the attention result is used directly here, whereas
        # TransformerDecoderLayer._sa_block indexes ``[0]`` into its result.
        # Only one of the two conventions can match what MultiheadAttention
        # actually returns when ``need_weights=False`` (a bare tensor vs an
        # (output, weights) tuple) — confirm against the attention module.
        x = self.self_attn(
            x,
            x,
            x,
            need_weights=False
        )
        x = self.dropout1(x)
        return x

    # Feedforward block: linear -> activation -> dropout -> linear -> dropout.
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
183class TransformerEncoder(nn.Module):
184 r"""TransformerEncoder is a stack of N encoder layers.
186 This class is adapted from pytorch :py:class:`torch.nn.TransformerEncoder`
188 This TransformerEncoder layer implements the original architecture described
189 in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
190 intent of this layer is as a reference implementation for foundational understanding
191 and thus it contains only limited features relative to newer Transformer architectures.
192 Given the fast pace of innovation in transformer-like architectures, we recommend
193 exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
194 to build efficient layers from building blocks in core or using higher
195 level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
197 .. warning::
198 All layers in the TransformerEncoder are initialized with the same parameters.
199 It is recommended to manually initialize the layers after creating the TransformerEncoder instance.
201 Args:
202 encoder_layer: an instance of the TransformerEncoderLayer() class (required).
203 num_layers: the number of sub-encoder-layers in the encoder (required).
204 norm: the layer normalization component (optional).
205 enable_nested_tensor: if True, input will automatically convert to nested tensor
206 (and convert back on output). This will improve the overall performance of
207 TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
209 Examples:
210 >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
211 >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
212 >>> src = torch.rand(10, 32, 512)
213 >>> out = transformer_encoder(src)
214 """
216 __constants__ = ["norm"]
218 def __init__(
219 self,
220 encoder_layer: "TransformerEncoderLayer",
221 num_layers: int,
222 norm: nn.Module | None = None,
223 ) -> None:
224 super().__init__()
225 self.layers = _get_clones(encoder_layer, num_layers)
226 self.num_layers = num_layers
227 self.norm = norm
229 def forward(
230 self,
231 src: torch.Tensor,
232 ) -> torch.Tensor:
233 r"""Pass the input through the encoder layers in turn.
235 Args:
236 src: the sequence to the encoder (required).
238 Shape:
239 see the docs in :class:`~torch.nn.Transformer`.
240 """
241 output = src
243 for mod in self.layers:
244 output = mod(
245 output,
246 )
248 if self.norm is not None:
249 output = self.norm(output)
251 return output
class TransformerDecoderLayer(nn.Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    Adapted from Pytorch :py:class:`torch.nn.TransformerDecoderLayer`.

    This standard decoder layer is based on the paper **Attention Is All You Need**.
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer. Default: :py:class:`CReLU`
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to self attention, multihead
            attention and feedforward operations, respectively. Otherwise it's done after.
            Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: the device on which the parameters are allocated (default=None).
        dtype: the dtype of the parameters (default=torch.complex64).

    Examples::

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8)
            memory = torch.rand(10, 32, 512, dtype=torch.complex64)
            tgt = torch.rand(20, 32, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)

        Alternatively, when ``batch_first`` is ``True``:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            decoder_layer = c_nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
            memory = torch.rand(32, 10, 512, dtype=torch.complex64)
            tgt = torch.rand(32, 20, 512, dtype=torch.complex64)
            out = decoder_layer(tgt, memory)
    """

    __constants__ = ["norm_first"]

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU and dtype=torch.complex64
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Masked self-attention over the target sequence.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Cross-attention: target attends to the encoder memory.
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            bias=bias,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # ``activation`` is a class (e.g. CReLU); instantiate it here.
        self.activation = activation()

        self._reset_parameters()

    def _reset_parameters(self):
        # Complex-valued Xavier-uniform init for the feedforward weights,
        # zero-init for their biases (when biases are enabled).
        complex_xavier_uniform_(self.linear1.weight)
        if self.linear1.bias is not None:
            nn.init.constant_(self.linear1.bias, 0)
        complex_xavier_uniform_(self.linear2.weight)
        if self.linear2.bias is not None:
            nn.init.constant_(self.linear2.bias, 0)

    # Adapted from Pytorch TransformerDecoderLayer
    # with CReLU instead of ReLU
    def __setstate__(self, state):
        # Backward compatibility: older pickles may lack the activation module.
        if "activation" not in state:
            state["activation"] = CReLU()
        super().__setstate__(state)

    # Same from Pytorch TransformerDecoderLayer
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        r"""Pass the inputs through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).

        Shape:
            see the docs in :class:`~torch.nn.Transformer`.
        """
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = tgt
        if self.norm_first:
            # Pre-norm: normalize before each sub-block, residual add outside.
            x = x + self._sa_block(
                self.norm1(x)
            )
            x = x + self._mha_block(
                self.norm2(x),
                memory,
            )
            x = x + self._ff_block(self.norm3(x))
        else:
            # Post-norm (original Transformer): residual add, then normalize.
            x = self.norm1(
                x + self._sa_block(x)
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory
                )
            )
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    # Same from Pytorch TransformerDecoderLayer
    def _sa_block(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE(review): the ``[0]`` below follows the PyTorch convention that
        # attention returns an (output, weights) tuple, yet
        # TransformerEncoderLayer._sa_block uses its attention result directly.
        # Only one of the two can match what MultiheadAttention actually
        # returns when ``need_weights=False``; if it returns a bare tensor,
        # ``[0]`` silently slices the first element along dim 0 — confirm
        # against the attention module.
        x = self.self_attn(
            x,
            x,
            x,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # Same from Pytorch TransformerDecoderLayer
    def _mha_block(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
    ) -> torch.Tensor:
        # Cross-attention: queries from the target, keys/values from memory.
        # NOTE(review): same ``[0]`` indexing question as in _sa_block above.
        x = self.multihead_attn(
            x,
            mem,
            mem,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    # Same from Pytorch TransformerDecoderLayer
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
450class TransformerDecoder(nn.Module):
451 r"""TransformerDecoder is a stack of N decoder layers.
453 Adapted from :py:class:`torch.nn.TransformerDecodder`.
455 This TransformerDecoder layer implements the original architecture described
456 in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
457 intent of this layer is as a reference implementation for foundational understanding
458 and thus it contains only limited features relative to newer Transformer architectures.
459 Given the fast pace of innovation in transformer-like architectures, we recommend
460 exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
461 to build efficient layers from building blocks in core or using higher
462 level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
464 .. warning::
465 All layers in the TransformerDecoder are initialized with the same parameters.
466 It is recommended to manually initialize the layers after creating the TransformerDecoder instance.
468 Args:
469 decoder_layer: an instance of the TransformerDecoderLayer() class (required).
470 num_layers: the number of sub-decoder-layers in the decoder (required).
471 norm: the layer normalization component (optional).
473 Examples:
474 >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
475 >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
476 >>> memory = torch.rand(10, 32, 512)
477 >>> tgt = torch.rand(20, 32, 512)
478 >>> out = transformer_decoder(tgt, memory)
479 """
481 __constants__ = ["norm"]
483 def __init__(
484 self,
485 decoder_layer: "TransformerDecoderLayer",
486 num_layers: int,
487 norm: nn.Module | None = None,
488 ) -> None:
489 super().__init__()
490 self.layers = _get_clones(decoder_layer, num_layers)
491 self.num_layers = num_layers
492 self.norm = norm
494 def forward(
495 self,
496 tgt: torch.Tensor,
497 memory: torch.Tensor,
498 ) -> torch.Tensor:
499 r"""Pass the inputs (and mask) through the decoder layer in turn.
501 Args:
502 tgt: the sequence to the decoder (required).
503 memory: the sequence from the last layer of the encoder (required).
505 Shape:
506 see the docs in :class:`~torch.nn.Transformer`.
507 """
508 output = tgt
510 for mod in self.layers:
511 output = mod(
512 output,
513 memory
514 )
516 if self.norm is not None:
517 output = self.norm(output)
519 return output
class Transformer(nn.Module):
    r"""A complex-valued transformer model.

    Adapted from :py:class:`torch.nn.Transformer`.

    User is able to modify the attributes as needed. The architecture is based
    on the paper **Attention Is All You Need**. Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser,
    and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010.

    The :py:class:`MultiheadAttention` implementation is based on the paper
    **Building blocks for a complex-valued transformer architecture**.
    Florian Eilers, Xiaoyi Jiang. 2023. In International Conference on
    Acoustics, Speech, and Signal Processing (ICASSP).

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer. Default: :py:class:`CReLU`.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
            other attention and feedforward operations, otherwise after. Default: ``False`` (after).
        bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
            bias. Default: ``True``.
        device: the device on which the parameters are allocated (default=None).
        dtype: the dtype of the parameters (default=torch.complex64).

    Examples:

        .. code-block:: python

            import torchcvnn as c_nn
            import torch

            transformer_model = c_nn.Transformer(nhead=16, num_encoder_layers=12)
            src = torch.rand((10, 32, 512), dtype=torch.complex64)
            tgt = torch.rand((20, 32, 512), dtype=torch.complex64)
            out = transformer_model(src, tgt)

    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = CReLU,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        bias: bool = True,
        device=None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        def final_norm() -> LayerNorm:
            # Normalization applied after the last encoder/decoder layer.
            return LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)

        # Encoder stack: one prototype layer cloned num_encoder_layers times.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                bias=bias,
                **factory_kwargs,
            ),
            num_encoder_layers,
            final_norm(),
        )

        # Decoder stack: one prototype layer cloned num_decoder_layers times.
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                bias=bias,
                **factory_kwargs,
            ),
            num_decoder_layers,
            final_norm(),
        )

        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        r"""Encode ``src``, then decode ``tgt`` against the encoded memory.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
        """
        memory = self.encoder(src)
        return self.decoder(tgt, memory)