Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/normalization.py: 95%

83 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Union, List 

25import math 

26import numbers 

27import functools 

28import operator 

29 

30# External imports 

31import torch 

32from torch import Size 

33import torch.nn as nn 

34import torch.nn.init as init 

35 

36# Local imports 

37import torchcvnn.nn.modules.batchnorm as bn 

38 

# Accepted ways of specifying the trailing dimensions to normalize: a single
# integer, a list of integers or a torch.Size.
_shape_t = Union[int, List[int], Size]


class LayerNorm(nn.Module):
    r"""
    Implementation of the torch.nn.LayerNorm for complex numbers.

    The complex input is processed as pairs of reals (real part, imaginary
    part): it is centered, whitened with the inverse square root of its 2x2
    covariance matrix and then, when ``elementwise_affine`` is `True`,
    multiplied by a learnable real 2x2 matrix per element and shifted by a
    learnable complex bias.

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the diagonal of the covariance matrices for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element affine parameters initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}` (for weights) and zeros (for biases). Default: `True`
        bias (bool): if set to `False`, the layer will not learn an additive bias (only relevant when `elementwise_affine` is `True`). Default: `True`
        device (torch.device): device on which the parameters are allocated. Default: `None`
        dtype (torch.dtype): complex dtype of the bias; the weights use the matching real dtype. Default: `torch.complex64`
    """

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # Number of elements covered by normalized_shape once flattened
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)

        if self.elementwise_affine:
            # The weights are real valued 2x2 matrices multiplied (bmm) with
            # the real view of the input, hence they must use the real dtype
            # matching the complex dtype (float32 for complex64, ...).
            # For a non complex dtype, the default dtype is kept.
            probe = torch.empty(0, dtype=dtype)
            weight_dtype = torch.view_as_real(probe).dtype if probe.is_complex() else None
            self.weight = torch.nn.parameter.Parameter(
                torch.empty(
                    (self.combined_dimensions, 2, 2), device=device, dtype=weight_dtype
                )
            )
            if bias:
                self.bias = torch.nn.parameter.Parameter(
                    torch.empty((self.combined_dimensions,), device=device, dtype=dtype)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weight and bias. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.

        The bias, if any, is initialized to :math:`0`.
        """
        if not self.elementwise_affine:
            # Neither weight nor bias to initialize
            return

        diag_value = 1 / math.sqrt(2.0)
        with torch.no_grad():
            # Initialize all the weights to zeros
            init.zeros_(self.weight)
            # And then fill in the diagonal with 1/sqrt(2)
            # w is C, 2, 2
            self.weight.view(-1, 2, 2)[:, 0, 0] = diag_value
            self.weight.view(-1, 2, 2)[:, 1, 1] = diag_value
            # The bias is None when the module is built with bias=False
            if self.bias is not None:
                init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z (torch.Tensor): complex tensor whose trailing dimensions have
                `combined_dimensions` elements in total

        Returns:
            A complex tensor with the same shape as `z`
        """
        # z: *, normalized_shape[0] , ..., normalized_shape[-1]
        # reshape (instead of view) also accepts non contiguous inputs
        # z_ravel: combined_dimensions, num_samples
        z_ravel = z.reshape(-1, self.combined_dimensions).transpose(0, 1)

        # Compute the means, one per element of normalized_shape
        # NOTE(review): the statistics are computed over the flattened
        # leading dimensions (the samples) for every element of
        # normalized_shape, while torch.nn.LayerNorm computes them per sample
        # over normalized_shape. Behavior kept as is; confirm it is intended.
        mus = z_ravel.mean(dim=-1, keepdim=True)  # combined_dimensions, 1

        # Center the inputs and see the complex numbers as 2 reals
        z_centered = torch.view_as_real(
            z_ravel - mus
        )  # combined_dimensions, num_samples, 2

        # Compute the variances and covariances of the real/imaginary parts
        # (expected to be combined_dimensions, 2, 2 -- see bn.batch_cov)
        covs = bn.batch_cov(z_centered, centered=True)

        # Invert the regularized covariance to scale
        regularizer = self.eps * torch.eye(2, device=covs.device, dtype=covs.dtype)
        invsqrt_covs = bn.inv_sqrt_2x2(covs + regularizer)  # combined_dimensions, 2, 2

        # Note: the z_centered.transpose is to make
        # z_centered from (combined_dimensions, num_samples, 2) to (combined_dimensions, 2, num_samples)
        # So that the batch matrix multiply works as expected
        # where invsqrt_covs is (combined_dimensions, 2, 2)
        outz = torch.bmm(invsqrt_covs, z_centered.transpose(1, 2))

        # Scale by gamma, if any
        # weight is (combined_dimensions, 2, 2) real valued
        if self.weight is not None:
            outz = torch.bmm(self.weight, outz)  # combined_dimensions, 2, num_samples

        outz = outz.transpose(1, 2).contiguous()  # combined_dimensions, num_samples, 2
        outz = torch.view_as_complex(outz)  # combined_dimensions, num_samples

        # Shift by beta, if any
        # bias is (combined_dimensions, ) complex dtype
        if self.bias is not None:
            outz = outz + self.bias.view(-1, 1)

        outz = outz.transpose(0, 1).contiguous()  # num_samples, combined_dimensions

        return outz.view(z.shape)

149 

150 

class RMSNorm(nn.Module):
    r"""
    Implementation of the torch.nn.RMSNorm for complex numbers.

    The complex input is processed as pairs of reals (real part, imaginary
    part): without any centering, it is whitened with the inverse square
    root of its 2x2 second order moment matrix and then, when
    ``elementwise_affine`` is `True`, multiplied by a learnable real 2x2
    matrix per element. There is no bias.

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the diagonal of the covariance matrices for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element weights initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}`. Default: `True`
        device (torch.device): device on which the weights are allocated. Default: `None`
        dtype (torch.dtype): complex dtype of the expected inputs; the weights use the matching real dtype. Default: `torch.complex64`
    """

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # Number of elements covered by normalized_shape once flattened
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)

        if self.elementwise_affine:
            # The weights are real valued 2x2 matrices multiplied (bmm) with
            # the real view of the input, hence they must use the real dtype
            # matching the complex dtype (float32 for complex64, ...).
            # For a non complex dtype, the default dtype is kept.
            probe = torch.empty(0, dtype=dtype)
            weight_dtype = torch.view_as_real(probe).dtype if probe.is_complex() else None
            self.weight = torch.nn.parameter.Parameter(
                torch.empty(
                    (self.combined_dimensions, 2, 2), device=device, dtype=weight_dtype
                )
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weights. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.
        """
        if not self.elementwise_affine:
            # No weight to initialize
            return

        diag_value = 1 / math.sqrt(2.0)
        with torch.no_grad():
            # Initialize all the weights to zeros
            init.zeros_(self.weight)
            # And then fill in the diagonal with 1/sqrt(2)
            # w is C, 2, 2
            self.weight.view(-1, 2, 2)[:, 0, 0] = diag_value
            self.weight.view(-1, 2, 2)[:, 1, 1] = diag_value

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass

        Arguments:
            z (torch.Tensor): complex tensor whose trailing dimensions have
                `combined_dimensions` elements in total

        Returns:
            A complex tensor with the same shape as `z`
        """
        # z: *, normalized_shape[0] , ..., normalized_shape[-1]
        # reshape (instead of view) also accepts non contiguous inputs
        # z_ravel: combined_dimensions, num_samples
        z_ravel = z.reshape(-1, self.combined_dimensions).transpose(0, 1)

        # See the complex numbers as 2 reals
        z_ravel = torch.view_as_real(z_ravel)  # combined_dimensions, num_samples, 2

        # Compute the variances and covariances of the real/imaginary parts.
        # The inputs are not centered here: they are passed with
        # centered=True so that no mean is involved (root mean square).
        covs = bn.batch_cov(z_ravel, centered=True)

        # Invert the regularized covariance to scale
        regularizer = self.eps * torch.eye(2, device=covs.device, dtype=covs.dtype)
        invsqrt_covs = bn.inv_sqrt_2x2(covs + regularizer)  # combined_dimensions, 2, 2

        # Note: the z_ravel.transpose is to make
        # z_ravel from (combined_dimensions, num_samples, 2) to (combined_dimensions, 2, num_samples)
        # So that the batch matrix multiply works as expected
        # where invsqrt_covs is (combined_dimensions, 2, 2)
        outz = torch.bmm(invsqrt_covs, z_ravel.transpose(1, 2))

        # Scale by gamma, if any
        # weight is (combined_dimensions, 2, 2) real valued
        if self.weight is not None:
            outz = torch.bmm(self.weight, outz)  # combined_dimensions, 2, num_samples

        outz = outz.transpose(1, 2).contiguous()  # combined_dimensions, num_samples, 2
        outz = torch.view_as_complex(outz)  # combined_dimensions, num_samples

        outz = outz.transpose(0, 1).contiguous()  # num_samples, combined_dimensions

        return outz.view(z.shape)