# MIT License

# Copyright (c) 2024 Quentin Gabot

# The following is an adaptation of the initialization code from the
# PyTorch library for complex-valued neural networks

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
import math
import warnings

# External imports
import torch
import torch.nn as nn
def complex_kaiming_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{2\,\text{fan\_mode}}}

    Also known as He initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backward pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended for use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.kaiming_normal_` function.
    """
    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain / math.sqrt(fan)) / math.sqrt(2)

    return nn.init._no_grad_normal_(tensor, 0.0, std)
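
# --- Illustrative sketch (an addition, not part of the original module) ---
# Assuming `Tensor.normal_` fills the real and imaginary parts of a complex
# tensor independently, each with variance std**2, the complex variance
# E[|w|^2] is 2 * std**2; the extra 1/sqrt(2) factor above then restores the
# real-valued Kaiming target gain**2 / fan. A minimal check of that claim:
def _demo_complex_kaiming_normal_variance() -> None:
    fan_in = 512
    w = torch.empty(1024, fan_in, dtype=torch.complex64)
    complex_kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
    expected = 2.0 / fan_in  # gain**2 / fan_in, with gain = sqrt(2) for relu
    empirical = w.abs().pow(2).mean().item()  # E[|w|^2] since the mean is 0
    print(f"empirical {empirical:.5f} vs expected {expected:.5f}")
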
def complex_kaiming_uniform_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{2\,\text{fan\_mode}}}

    Also known as He initialization.

    Arguments:
        tensor: an n-dimensional :external:py:class:`torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backward pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended for use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.kaiming_uniform_` function.
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan) / math.sqrt(2)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -bound, bound)
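
# --- Illustrative sketch (an addition, not part of the original module) ---
# Under the same assumption that `uniform_` fills the real and imaginary
# parts independently, both components should stay within [-bound, bound],
# with bound = gain * sqrt(3 / (2 * fan_in)):
def _demo_complex_kaiming_uniform_bounds() -> None:
    fan_in = 512
    w = torch.empty(256, fan_in, dtype=torch.complex64)
    complex_kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")
    bound = math.sqrt(2.0) * math.sqrt(3.0 / (2.0 * fan_in))  # gain = sqrt(2)
    assert w.real.abs().max().item() <= bound
    assert w.imag.abs().max().item() <= bound
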
def complex_xavier_uniform_(
    tensor: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{2(\text{fan\_in} + \text{fan\_out})}}

    Also known as Glorot initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: an optional parameter of the non-linear function
        nonlinearity: the non-linear function used to compute the gain

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_xavier_uniform_(w, nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.xavier_uniform_` function.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) / math.sqrt(2)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -bound, bound)
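
# --- Hypothetical usage sketch (layer sizes and the zero bias below are
# illustrative assumptions, not conventions taken from this library) ---
# How such an initializer might be wired into a complex-valued layer:
def _demo_xavier_uniform_on_linear() -> nn.Linear:
    layer = nn.Linear(128, 64, dtype=torch.complex64)
    complex_xavier_uniform_(layer.weight, nonlinearity="relu")
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer
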
def complex_xavier_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{2(\text{fan\_in} + \text{fan\_out})}}

    Also known as Glorot initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: an optional parameter of the non-linear function
        nonlinearity: the non-linear function used to compute the gain

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_xavier_normal_(w, nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.xavier_normal_` function.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain * math.sqrt(2.0 / float(fan_in + fan_out))) / math.sqrt(2)

    return nn.init._no_grad_normal_(tensor, 0.0, std)
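
# --- Illustrative smoke test (this __main__ guard is an addition, not part
# of the original module): run each initializer on a fresh complex tensor.
if __name__ == "__main__":
    for init_fn in (
        complex_kaiming_normal_,
        complex_kaiming_uniform_,
        complex_xavier_uniform_,
        complex_xavier_normal_,
    ):
        w = torch.empty(64, 32, dtype=torch.complex64)
        init_fn(w, nonlinearity="relu")
        print(f"{init_fn.__name__}: E[|w|^2] = {w.abs().pow(2).mean().item():.5f}")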