Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/initialization.py: 93%

# MIT License

# Copyright (c) 2024 Quentin Gabot

# The following is an adaptation of the initialization code from the
# PyTorch library for complex-valued neural networks.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Standard imports
import math
import warnings

# External imports
import torch
import torch.nn as nn
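# Note: each initializer below reuses the corresponding real-valued PyTorch
# formula and divides the standard deviation by an extra sqrt(2). The intent
# is that a complex weight spreads its variance over a real and an imaginary
# part, so halving the per-component variance keeps the overall spread of the
# complex weights in line with the real-valued scheme.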

def complex_kaiming_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::

        \text{std} = \frac{\text{gain}}{\sqrt{2 \times \text{fan\_mode}}}

    Also known as He initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.kaiming_normal_` function.
    """

    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain / math.sqrt(fan)) / math.sqrt(2)

    return nn.init._no_grad_normal_(tensor, 0.0, std)


def complex_kaiming_uniform_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::

        \text{bound} = \text{gain} \times \sqrt{\frac{3}{2 \times \text{fan\_mode}}}

    Also known as He initialization.

    Arguments:
        tensor: an n-dimensional :external:py:class:`torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.kaiming_uniform_` function.
    """

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = nn.init._calculate_correct_fan(tensor, mode)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan) / math.sqrt(2)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
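    # For a uniform distribution U(-bound, bound), the variance is bound**2 / 3,
    # so this choice of bound reproduces exactly the standard deviation
    # computed above.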

    return nn.init._no_grad_uniform_(tensor, -bound, bound)


def complex_xavier_uniform_(
    tensor: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::

        a = \text{gain} \times \sqrt{\frac{6}{2(\text{fan\_in} + \text{fan\_out})}}

    Also known as Glorot initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: an optional parameter to the non-linear function
        nonlinearity: the non-linearity used to compute the gain

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_xavier_uniform_(w, nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.xavier_uniform_` function.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) / math.sqrt(2)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -bound, bound)


def complex_xavier_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "leaky_relu",
) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::

        \text{std} = \text{gain} \times \sqrt{\frac{2}{2(\text{fan\_in} + \text{fan\_out})}}

    Also known as Glorot initialization.

    Arguments:
        tensor: an n-dimensional `torch.Tensor`
        a: an optional parameter to the non-linear function
        nonlinearity: the non-linearity used to compute the gain

    Examples:
        >>> w = torch.empty(3, 5, dtype=torch.complex64)
        >>> c_nn.init.complex_xavier_normal_(w, nonlinearity='relu')

    This implementation is a minor adaptation of the
    :external:py:func:`torch.nn.init.xavier_normal_` function.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    gain = nn.init.calculate_gain(nonlinearity, a)
    std = (gain * math.sqrt(2.0 / float(fan_in + fan_out))) / math.sqrt(2)

    return nn.init._no_grad_normal_(tensor, 0.0, std)
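

# Minimal usage sketch: fill a fresh complex-valued weight tensor with each
# initializer and report the empirical spread of its real and imaginary parts.
# The (256, 512) shape is an arbitrary choice for illustration.
if __name__ == "__main__":
    for init_fn in (
        complex_kaiming_normal_,
        complex_kaiming_uniform_,
        complex_xavier_uniform_,
        complex_xavier_normal_,
    ):
        w = torch.empty(256, 512, dtype=torch.complex64)
        init_fn(w, nonlinearity="relu")
        print(
            f"{init_fn.__name__}: real std={w.real.std():.4f}, "
            f"imag std={w.imag.std():.4f}"
        )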