torchcvnn.nn.init
- torchcvnn.nn.modules.initialization.complex_kaiming_normal_(tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') → Tensor
Fills the input Tensor with values according to the method described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from \(\mathcal{N}(0, \text{std}^2)\) where
\[\text{std} = \frac{\text{gain}}{\sqrt{2\,\text{fan\_mode}}}\]
Also known as He initialization.
- Parameters:
tensor – an n-dimensional torch.Tensor
a – the negative slope of the rectifier used after this layer (only used with 'leaky_relu')
mode – either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass.
nonlinearity – the non-linear function (nn.functional name), recommended to use only with 'relu' or 'leaky_relu' (default).
Examples
>>> w = torch.empty(3, 5, dtype=torch.complex64)
>>> c_nn.init.complex_kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
This implementation is a minor adaptation of the torch.nn.init.kaiming_normal_() function.
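As a quick sanity check of the formula above, the sketch below compares the empirical deviation of a freshly initialized weight against \(\text{gain} / \sqrt{2\,\text{fan\_in}}\). It assumes torchcvnn.nn is imported as c_nn, as in the example, and that std refers to the standard deviation of the full complex values (i.e. \(\mathbb{E}|w|^2 = \text{std}^2\)); check the source for the exact per-component convention.

>>> import math
>>> import torch
>>> import torchcvnn.nn as c_nn
>>> w = torch.empty(256, 512, dtype=torch.complex64)
>>> w = c_nn.init.complex_kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
>>> fan_in = w.size(1)                           # 512 for a 2D weight
>>> gain = torch.nn.init.calculate_gain('relu')  # sqrt(2)
>>> expected_std = gain / math.sqrt(2 * fan_in)
>>> # Close up to sampling noise -> True if the convention above holds
>>> abs(w.std().item() - expected_std) < 0.1 * expected_std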
- torchcvnn.nn.modules.initialization.complex_kaiming_uniform_(tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') → Tensor
Fills the input Tensor with values according to the method described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015), using a uniform distribution. The resulting tensor will have values sampled from \(\mathcal{U}(-\text{bound}, \text{bound})\) where
\[\text{bound} = \text{gain} \times \sqrt{\frac{3}{2\,\text{fan\_mode}}}\]
Also known as He initialization.
- Parameters:
tensor – an n-dimensional torch.Tensor
a – the negative slope of the rectifier used after this layer (only used with 'leaky_relu')
mode – either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass.
nonlinearity – the non-linear function (nn.functional name), recommended to use only with 'relu' or 'leaky_relu' (default).
Examples
>>> w = torch.empty(3, 5, dtype=torch.complex64)
>>> c_nn.init.complex_kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
This implementation is a minor adaptation of the torch.nn.init.kaiming_uniform_() function.
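The bound can be checked in the same spirit. Whether the implementation draws the real and imaginary parts independently from \(\mathcal{U}(-\text{bound}, \text{bound})\) or samples complex values within the bound directly is an implementation detail; under either reading, the modulus of every entry stays below \(\text{bound} \times \sqrt{2}\), which the sketch below verifies:

>>> import math
>>> import torch
>>> import torchcvnn.nn as c_nn
>>> w = torch.empty(128, 256, dtype=torch.complex64)
>>> w = c_nn.init.complex_kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
>>> fan_in = w.size(1)
>>> gain = torch.nn.init.calculate_gain('relu')
>>> bound = gain * math.sqrt(3 / (2 * fan_in))
>>> # Per-component sampling caps each entry's modulus at bound * sqrt(2)
>>> w.abs().max().item() <= bound * math.sqrt(2)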
- torchcvnn.nn.modules.initialization.complex_xavier_normal_(tensor: Tensor, a: float = 0, nonlinearity: str = 'leaky_relu') → Tensor
Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a normal distribution. The resulting tensor will have values sampled from \(\mathcal{N}(0, \text{std}^2)\) where
\[\text{std} = \text{gain} \times \sqrt{\frac{2}{2(\text{fan\_in} + \text{fan\_out})}}\]
Also known as Glorot initialization.
- Parameters:
tensor – an n-dimensional torch.Tensor
a – an optional parameter to the non-linear function
nonlinearity – the non-linearity used to compute the gain
Examples
>>> w = torch.empty(3, 5, dtype=torch.complex64)
>>> c_nn.init.complex_xavier_normal_(w, nonlinearity='relu')
This implementation is a minor adaptation of the torch.nn.init.xavier_normal_() function.
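Note that the two factors of 2 in the formula cancel, so the standard deviation reduces to \(\text{gain} / \sqrt{\text{fan\_in} + \text{fan\_out}}\). The sketch below checks this empirically, under the same assumption as above that std refers to the deviation of the full complex values:

>>> import math
>>> import torch
>>> import torchcvnn.nn as c_nn
>>> w = torch.empty(256, 512, dtype=torch.complex64)
>>> w = c_nn.init.complex_xavier_normal_(w, nonlinearity='relu')
>>> fan_out, fan_in = w.size(0), w.size(1)
>>> gain = torch.nn.init.calculate_gain('relu')
>>> expected_std = gain / math.sqrt(fan_in + fan_out)   # the 2s cancel
>>> abs(w.std().item() - expected_std) < 0.1 * expected_std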
- torchcvnn.nn.modules.initialization.complex_xavier_uniform_(tensor: Tensor, a: float = 0, nonlinearity: str = 'leaky_relu') → Tensor
Fills the input Tensor with values according to the method described in Understanding the difficulty of training deep feedforward neural networks - Glorot, X. & Bengio, Y. (2010), using a uniform distribution. The resulting tensor will have values sampled from \(\mathcal{U}(-a, a)\) where
\[a = \text{gain} \times \sqrt{\frac{6}{2(\text{fan\_in} + \text{fan\_out})}}\]
Also known as Glorot initialization.
- Parameters:
tensor – an n-dimensional torch.Tensor
a – an optional parameter to the non-linear function
nonlinearity – the non-linearity used to compute the gain
Examples
>>> w = torch.empty(3, 5, dtype=torch.complex64)
>>> c_nn.init.complex_xavier_uniform_(w, nonlinearity='relu')
This implementation is a minor adaptation of the torch.nn.init.xavier_uniform_() function.
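In practice these initializers are applied in place to the complex parameters of a model after construction. The helper below is a hypothetical convenience, not part of torchcvnn, illustrating that pattern:

>>> import torch
>>> import torchcvnn.nn as c_nn
>>> def init_complex_weights_(model: torch.nn.Module) -> None:
...     # Hypothetical helper: re-initialize every complex weight matrix
...     # (2-D or higher, so biases are skipped) of a model in place.
...     with torch.no_grad():
...         for p in model.parameters():
...             if p.is_complex() and p.dim() >= 2:
...                 c_nn.init.complex_xavier_uniform_(p, nonlinearity='relu')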