Coverage for  / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / normalization.py: 85%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-24 13:56 +0000

1# MIT License 

2 

3# Copyright (c) 2024 Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Union, List 

25import math 

26import numbers 

27import functools 

28import operator 

29 

30# External imports 

31import torch 

32from torch import Size 

33import torch.nn as nn 

34import torch.nn.init as init 

35 

36# Local imports 

37import torchcvnn.nn.modules.batchnorm as bn 

38 

39_shape_t = Union[int, List[int], Size] 

40 

41 

class LayerNorm(nn.Module):
    r"""
    Implementation of the torch.nn.LayerNorm for complex numbers.

    The input is whitened over the normalized dimensions: for every sample,
    the 2x2 covariance matrix of the (real, imaginary) parts is estimated and
    its inverse square root is applied, followed by an optional learnable
    affine transform (a real 2x2 matrix and a complex bias per element).

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element affine parameters initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}` (for weights) and zeros (for biases). Default: `True`
        bias (bool): if set to `False`, the layer will not learn an additive bias
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps

        self.elementwise_affine = elementwise_affine

        # Total number of elements covered by the normalized dimensions
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)
        if self.elementwise_affine:
            # One real-valued 2x2 scaling matrix per normalized element
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((self.combined_dimensions, 2, 2), device=device)
            )
            if bias:
                # One complex-valued additive bias per normalized element
                self.bias = torch.nn.parameter.Parameter(
                    torch.empty((self.combined_dimensions,), device=device, dtype=dtype)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weight and bias. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.

        The bias, when present, is initialized to :math:`0`.
        """
        with torch.no_grad():
            if self.elementwise_affine:
                # Initialize all the weights to zeros
                init.zeros_(self.weight)
                # And then fill in the diagonal with 1/sqrt(2)
                # w is C, 2, 2
                self.weight.view(-1, 2, 2)[:, 0, 0] = 1 / math.sqrt(2.0)
                self.weight.view(-1, 2, 2)[:, 1, 1] = 1 / math.sqrt(2.0)
                # Bugfix: when constructed with bias=False, self.bias is None
                # and zeroing it would crash; only reset it when it exists.
                if self.bias is not None:
                    init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: complex tensor of shape :math:`(*, normalized\_shape[0], ..., normalized\_shape[-1])`

        Returns:
            A complex tensor with the same shape as the input.
        """
        # z: *, normalized_shape[0] , ..., normalized_shape[-1]
        # reshape (rather than view) also accepts non-contiguous inputs
        z_ravel = z.reshape(-1, self.combined_dimensions).transpose(0, 1)

        # Compute the mean of every normalized element over the samples
        mus = z_ravel.mean(dim=-1)  # (combined_dimensions,)

        # Center the inputs
        z_centered = z_ravel - mus.unsqueeze(-1)
        z_centered = torch.view_as_real(
            z_centered
        )  # combined_dimensions, num_samples, 2

        # Transform the complex numbers as 2 reals to compute the variances and
        # covariances
        covs = bn.batch_cov(z_centered, centered=True)

        # Invert the covariance to scale
        invsqrt_covs = bn.inv_sqrt_2x2(
            covs + self.eps * torch.eye(2, device=covs.device)
        )  # combined_dimensions, 2, 2
        # z_centered is transposed from (combined_dimensions, num_samples, 2)
        # to (combined_dimensions, 2, num_samples) so the batched matrix
        # multiply against invsqrt_covs (combined_dimensions, 2, 2) lines up
        outz = torch.bmm(invsqrt_covs, z_centered.transpose(1, 2))
        outz = outz.contiguous()  # combined_dimensions, 2, num_samples

        # Shift by beta and scale by gamma
        # weight is (combined_dimensions, 2, 2) real valued
        if self.elementwise_affine:
            outz = torch.bmm(self.weight, outz)  # combined_dimensions, 2, num_samples
        outz = outz.transpose(1, 2).contiguous()  # combined_dimensions, num_samples, 2
        outz = torch.view_as_complex(outz)  # combined_dimensions, num_samples

        # bias is (combined_dimensions,) complex dtype, or None
        if self.bias is not None:
            outz += self.bias.view(-1, 1)

        outz = outz.transpose(0, 1).contiguous()  # num_samples, combined_dimensions

        outz = outz.view(z.shape)

        return outz

151 

class GroupNorm(nn.Module):
    r"""
    Implementation of Group Normalization for complex numbers.

    This class is adapted from pytorch :py:class:`torch.nn.GroupNorm`.

    It implements `Group Normalization <https://arxiv.org/abs/1803.08494>`_

    Arguments:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        affine (bool): a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if not self.affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            # The affine parameters are per channel (C), not per group:
            # a real 2x2 scaling matrix and a complex offset for each channel.
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((num_channels, 2, 2), device=device)
            )
            self.bias = torch.nn.parameter.Parameter(
                torch.empty((num_channels,), device=device, dtype=dtype)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Reset the affine parameters: every per-channel 2x2 weight becomes a
        diagonal matrix with diagonal :math:`\frac{1}{\sqrt{2}}` and every
        bias becomes zero. No-op when ``affine=False``.
        """
        if not self.affine:
            return
        with torch.no_grad():
            init.zeros_(self.weight)
            # weight is (C, 2, 2): set the diagonal entries of each matrix
            self.weight[:, 0, 0] = 1 / math.sqrt(2.0)
            self.weight[:, 1, 1] = 1 / math.sqrt(2.0)
            init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: complex tensor of shape (N, C, *spatial)

        Returns:
            A complex tensor with the same shape as the input.
        """
        batch, channels = z.shape[:2]
        if channels != self.num_channels:
            raise ValueError(f"Expected {self.num_channels} channels, got {channels}")

        groups = self.num_groups
        # Split the channels into groups: (N, G, C//G, *spatial)
        grouped = z.view(batch, groups, channels // groups, *z.shape[2:])
        # One "feature" per (sample, group) pair; the remaining C//G * spatial
        # entries are the observations the statistics are computed over.
        flattened = grouped.view(batch * groups, -1)

        # --- Whitening: center, then scale by the inverse sqrt covariance ---
        means = flattened.mean(dim=-1)  # (N*G,)
        centered = flattened - means.unsqueeze(-1)
        centered_ri = torch.view_as_real(centered)  # (N*G, Sample, 2)

        covariances = bn.batch_cov(centered_ri, centered=True)  # (N*G, 2, 2)
        whitener = bn.inv_sqrt_2x2(
            covariances + self.eps * torch.eye(2, device=covariances.device)
        )  # (N*G, 2, 2)

        normalized = torch.bmm(whitener, centered_ri.transpose(1, 2))  # (N*G, 2, Sample)
        normalized = normalized.transpose(1, 2).contiguous()  # (N*G, Sample, 2)
        normalized = torch.view_as_complex(normalized)  # (N*G, Sample)

        # Back to the input layout (N, C, *spatial) for the per-channel affine
        normalized = normalized.view(z.shape)

        if not self.affine:
            return normalized

        # --- Affine transform: per-channel 2x2 matrix, then complex offset ---
        # Bring C to the front and flatten the rest: (C, N*Spatial)
        per_channel = normalized.transpose(0, 1).reshape(channels, -1)
        per_channel_ri = torch.view_as_real(per_channel)  # (C, N*Spatial, 2)

        # weight (C, 2, 2) batch-multiplied against (C, 2, N*Spatial)
        scaled = torch.bmm(self.weight, per_channel_ri.transpose(1, 2))
        scaled = scaled.transpose(1, 2).contiguous()  # (C, N*Spatial, 2)
        scaled = torch.view_as_complex(scaled)  # (C, N*Spatial)

        shifted = scaled + self.bias.unsqueeze(-1)

        # Restore (C, N, *spatial), then swap back to (N, C, *spatial)
        return shifted.view(channels, batch, *z.shape[2:]).transpose(0, 1)

264 

class RMSNorm(nn.Module):
    r"""
    Implementation of the torch.nn.RMSNorm for complex numbers.

    Unlike :py:class:`LayerNorm`, the inputs are not centered: only the
    scaling by the inverse square root of the second-order moment matrix of
    the (real, imaginary) parts is applied, optionally followed by a
    learnable per-element affine transform (no bias).

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element affine parameters initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}`. Default: `True`
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps

        self.elementwise_affine = elementwise_affine

        # Total number of elements covered by the normalized dimensions
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)
        if self.elementwise_affine:
            # One real-valued 2x2 scaling matrix per normalized element
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((self.combined_dimensions, 2, 2), device=device)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weights. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.
        """
        with torch.no_grad():
            if self.elementwise_affine:
                # Initialize all the weights to zeros
                init.zeros_(self.weight)
                # And then fill in the diagonal with 1/sqrt(2)
                # w is C, 2, 2
                self.weight.view(-1, 2, 2)[:, 0, 0] = 1 / math.sqrt(2.0)
                self.weight.view(-1, 2, 2)[:, 1, 1] = 1 / math.sqrt(2.0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Arguments:
            z: complex tensor of shape :math:`(*, normalized\_shape[0], ..., normalized\_shape[-1])`

        Returns:
            A complex tensor with the same shape as the input.
        """
        # z: *, normalized_shape[0] , ..., normalized_shape[-1]
        # reshape (rather than view) also accepts non-contiguous inputs
        z_ravel = z.reshape(-1, self.combined_dimensions).transpose(0, 1)

        z_ravel = torch.view_as_real(z_ravel)  # combined_dimensions, num_samples, 2

        # Second-order moment of the (real, imag) pairs. The inputs are
        # deliberately NOT centered (RMS normalization); centered=True
        # presumably tells batch_cov to skip the mean subtraction, matching
        # its use on pre-centered data in LayerNorm — TODO confirm in
        # torchcvnn.nn.modules.batchnorm.
        covs = bn.batch_cov(z_ravel, centered=True)

        # Invert the covariance to scale
        invsqrt_covs = bn.inv_sqrt_2x2(
            covs + self.eps * torch.eye(2, device=covs.device)
        )  # combined_dimensions, 2, 2
        # z_ravel is transposed from (combined_dimensions, num_samples, 2) to
        # (combined_dimensions, 2, num_samples) so the batched matrix multiply
        # against invsqrt_covs (combined_dimensions, 2, 2) lines up
        outz = torch.bmm(invsqrt_covs, z_ravel.transpose(1, 2))
        outz = outz.contiguous()  # combined_dimensions, 2, num_samples

        # Scale by gamma
        # weight is (combined_dimensions, 2, 2) real valued
        if self.elementwise_affine:
            outz = torch.bmm(self.weight, outz)  # combined_dimensions, 2, num_samples
        outz = outz.transpose(1, 2).contiguous()  # combined_dimensions, num_samples, 2
        outz = torch.view_as_complex(outz)  # combined_dimensions, num_samples

        outz = outz.transpose(0, 1).contiguous()  # num_samples, combined_dimensions

        outz = outz.view(z.shape)

        return outz