Coverage for / home / runner / work / torchcvnn / torchcvnn / src / torchcvnn / nn / modules / normalization.py: 85%
133 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-24 13:56 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-24 13:56 +0000
1# MIT License
3# Copyright (c) 2024 Jeremy Fix
5# Permission is hereby granted, free of charge, to any person obtaining a copy
6# of this software and associated documentation files (the "Software"), to deal
7# in the Software without restriction, including without limitation the rights
8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9# copies of the Software, and to permit persons to whom the Software is
10# furnished to do so, subject to the following conditions:
12# The above copyright notice and this permission notice shall be included in
13# all copies or substantial portions of the Software.
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23# Standard imports
24from typing import Union, List
25import math
26import numbers
27import functools
28import operator
30# External imports
31import torch
32from torch import Size
33import torch.nn as nn
34import torch.nn.init as init
36# Local imports
37import torchcvnn.nn.modules.batchnorm as bn
39_shape_t = Union[int, List[int], Size]
class LayerNorm(nn.Module):
    r"""
    Implementation of the torch.nn.LayerNorm for complex numbers.

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element affine parameters initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}` (for weights) and zeros (for biases). Default: `True`
        bias (bool): if set to `False`, the layer will not learn an additive bias
    """

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        bias: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps

        self.elementwise_affine = elementwise_affine

        # Total number of normalized elements; the whitening statistics are
        # computed independently for each of them.
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)
        if self.elementwise_affine:
            # The scale is a real-valued 2x2 matrix per element, acting on the
            # stacked (Re, Im) components.
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((self.combined_dimensions, 2, 2), device=device)
            )
            if bias:
                self.bias = torch.nn.parameter.Parameter(
                    torch.empty((self.combined_dimensions,), device=device, dtype=dtype)
                )
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weight and bias. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.

        The bias, when present, is initialized to :math:`0`.
        """
        with torch.no_grad():
            if self.elementwise_affine:
                # Initialize all the weights to zeros
                init.zeros_(self.weight)
                # And then fill in the diagonal with 1/sqrt(2)
                # w is C, 2, 2
                self.weight.view(-1, 2, 2)[:, 0, 0] = 1 / math.sqrt(2.0)
                self.weight.view(-1, 2, 2)[:, 1, 1] = 1 / math.sqrt(2.0)
                # Bugfix: when constructed with bias=False the bias parameter
                # is registered as None; zeroing it would raise AttributeError.
                if self.bias is not None:
                    init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass: per-element whitening of the complex input
        followed by the optional learnable affine transform.
        """
        # z: *, normalized_shape[0] , ..., normalized_shape[-1]
        # reshape (rather than view) also accepts non-contiguous inputs.
        z_ravel = z.reshape(-1, self.combined_dimensions).transpose(0, 1)

        # Compute the means
        mus = z_ravel.mean(dim=-1)  # combined_dimensions

        # Center the inputs
        z_centered = z_ravel - mus.unsqueeze(-1)
        z_centered = torch.view_as_real(
            z_centered
        )  # combined_dimensions, num_samples, 2

        # Transform the complex numbers as 2 reals to compute the variances and
        # covariances
        covs = bn.batch_cov(z_centered, centered=True)

        # Invert the covariance to scale
        invsqrt_covs = bn.inv_sqrt_2x2(
            covs + self.eps * torch.eye(2, device=covs.device)
        )  # combined_dimensions, 2, 2
        # Note: the z_centered.transpose is to make
        # z_centered from (combined_dimensions, num_samples, 2) to (combined_dimensions, 2, num_samples)
        # So that the batch matrix multiply works as expected
        # where invsqrt_covs is (combined_dimensions, 2, 2)
        outz = torch.bmm(invsqrt_covs, z_centered.transpose(1, 2))
        outz = outz.contiguous()  # combined_dimensions, 2, num_samples

        # Shift by beta and scale by gamma
        # weight is (combined_dimensions, 2, 2) real valued
        if self.elementwise_affine:
            outz = torch.bmm(self.weight, outz)  # combined_dimensions, 2, num_samples
        outz = outz.transpose(1, 2).contiguous()  # combined_dimensions, num_samples, 2
        outz = torch.view_as_complex(outz)  # combined_dimensions, num_samples

        # bias is (combined_dimensions, ) complex dtype
        if getattr(self, "bias", None) is not None:
            outz += self.bias.view(-1, 1)

        outz = outz.transpose(0, 1).contiguous()  # num_samples, combined_dimensions

        outz = outz.view(z.shape)

        return outz
class GroupNorm(nn.Module):
    r"""
    Implementation of Group Normalization for complex numbers.

    This class is adapted from pytorch :py:class:`torch.nn.GroupNorm`.

    It implements `Group Normalization <https://arxiv.org/abs/1803.08494>`_

    Arguments:
        num_groups (int): number of groups to separate the channels into
        num_channels (int): number of channels expected in input
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        affine (bool): a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if not self.affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            # Affine parameters live per channel (C), not per group: a real
            # 2x2 matrix acting on (Re, Im), plus a complex additive bias.
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((num_channels, 2, 2), device=device)
            )
            self.bias = torch.nn.parameter.Parameter(
                torch.empty((num_channels,), device=device, dtype=dtype)
            )

        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Reset the affine parameters: each per-channel weight becomes
        :math:`\frac{1}{\sqrt{2}} I_2` and the bias becomes zero.
        """
        if not self.affine:
            return
        with torch.no_grad():
            init.zeros_(self.weight)
            diagonal = 1 / math.sqrt(2.0)
            self.weight[:, 0, 0] = diagonal
            self.weight[:, 1, 1] = diagonal
            init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass: whitening of the (Re, Im) components per
        (sample, group), then the optional per-channel affine transform.
        """
        # z: N, C, * (spatial dims)
        batch, channels = z.shape[:2]
        if channels != self.num_channels:
            raise ValueError(f"Expected {self.num_channels} channels, got {channels}")

        spatial = z.shape[2:]
        groups = self.num_groups

        # One row per (sample, group); each row holds the (C//G) * spatial
        # entries sharing the same statistics.
        z_flat = z.view(batch, groups, channels // groups, *spatial).view(
            batch * groups, -1
        )

        # 1. Whitening: center, then scale by the inverse square root of the
        # 2x2 covariance of the real/imaginary parts.
        centered = torch.view_as_real(
            z_flat - z_flat.mean(dim=-1, keepdim=True)
        )  # (N*G, Sample, 2)

        covs = bn.batch_cov(centered, centered=True)  # (N*G, 2, 2)
        scaling = bn.inv_sqrt_2x2(
            covs + self.eps * torch.eye(2, device=covs.device)
        )  # (N*G, 2, 2)

        whitened = torch.bmm(scaling, centered.transpose(1, 2))  # (N*G, 2, Sample)
        whitened = torch.view_as_complex(
            whitened.transpose(1, 2).contiguous()
        )  # (N*G, Sample)

        # Back to the input layout, ready for the per-channel affine step.
        out = whitened.view(z.shape)

        # 2. Affine transformation (if enabled)
        if not self.affine:
            return out

        # Channel axis first and tail flattened, so bmm applies the
        # per-channel 2x2 weight: (C, 2, 2) x (C, 2, N*Spatial).
        per_channel = torch.view_as_real(
            out.transpose(0, 1).reshape(channels, -1)
        )  # (C, N*Spatial, 2)
        scaled = torch.bmm(self.weight, per_channel.transpose(1, 2))
        scaled = torch.view_as_complex(
            scaled.transpose(1, 2).contiguous()
        )  # (C, N*Spatial)

        # Complex bias, broadcast over samples and spatial positions.
        scaled = scaled + self.bias.unsqueeze(-1)

        # Restore (C, N, *) then (N, C, *).
        return scaled.view(channels, batch, *spatial).transpose(0, 1)
class RMSNorm(nn.Module):
    r"""
    Implementation of the torch.nn.RMSNorm for complex numbers.

    Arguments:
        normalized_shape (int or list or torch.Size): input shape from an expected input of size :math:`(*, normalized\_shape[0], normalized\_shape[1], ..., normalized\_shape[-1])`
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): a boolean value that when set to `True`, this module has learnable per-element affine parameters initialized to a diagonal matrix with diagonal element :math:`\frac{1}{\sqrt{2}}`. Default: `True`
    """

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device: torch.device = None,
        dtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps

        self.elementwise_affine = elementwise_affine

        # Number of elements over which the statistics are computed.
        self.combined_dimensions = functools.reduce(operator.mul, self.normalized_shape)
        if not self.elementwise_affine:
            self.register_parameter("weight", None)
        else:
            # One real 2x2 scaling matrix per normalized element,
            # acting on the stacked (Re, Im) components.
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((self.combined_dimensions, 2, 2), device=device)
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        r"""
        Initialize the weights. The weight is initialized to a diagonal
        matrix with diagonal :math:`\frac{1}{\sqrt{2}}`.
        """
        if not self.elementwise_affine:
            return
        with torch.no_grad():
            # Zero everything, then put 1/sqrt(2) on each 2x2 diagonal.
            init.zeros_(self.weight)
            diagonal = 1 / math.sqrt(2.0)
            self.weight[:, 0, 0] = diagonal
            self.weight[:, 1, 1] = diagonal

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Contrary to LayerNorm the input is NOT centered: the inverse square
        root of the second-moment 2x2 matrix of the (Re, Im) parts scales the
        raw input, which is the RMS-style normalization. There is no bias.
        """
        # z: *, normalized_shape[0], ..., normalized_shape[-1]
        feats = torch.view_as_real(
            z.view(-1, self.combined_dimensions).transpose(0, 1)
        )  # combined_dimensions, num_samples, 2

        # centered=True: no mean subtraction inside batch_cov, so these are
        # the (non-centered) second-moment matrices — intended for RMSNorm.
        moments = bn.batch_cov(feats, centered=True)

        # Invert the moment matrices to build the scaling.
        scaling = bn.inv_sqrt_2x2(
            moments + self.eps * torch.eye(2, device=moments.device)
        )  # combined_dimensions, 2, 2

        # The transpose turns feats into (combined_dimensions, 2, num_samples)
        # so the batch matrix multiply with scaling works as expected.
        out = torch.bmm(scaling, feats.transpose(1, 2)).contiguous()

        # Scale by gamma: weight is (combined_dimensions, 2, 2), real valued.
        if self.elementwise_affine:
            out = torch.bmm(self.weight, out)  # combined_dimensions, 2, num_samples

        out = torch.view_as_complex(
            out.transpose(1, 2).contiguous()
        )  # combined_dimensions, num_samples
        return out.transpose(0, 1).contiguous().view(z.shape)