# MIT License

# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
import math

# External imports
import torch
import torch.nn as nn
import torch.nn.init as init


def batch_cov(points: torch.Tensor, centered: bool = False) -> torch.Tensor:
    """
    Batched covariance computation.
    Adapted from https://stackoverflow.com/a/71357620/2164582

    Arguments:
        points: the :math:`(B, N, D)` input tensor from which to compute :math:`B` covariances
        centered: if `True`, assumes the :math:`N` vectors of every batch are already centered. Default: `False`

    Returns:
        bcov: the covariances as a :math:`(B, D, D)` tensor
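
    Example (a minimal sketch, checked against ``torch.cov``):

        >>> points = torch.randn(8, 100, 2)
        >>> bcov = batch_cov(points)  # (8, 2, 2)
        >>> ref = torch.stack([torch.cov(p.T) for p in points])
        >>> torch.allclose(bcov, ref, atol=1e-5)
        True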
43 """
    B, N, D = points.size()
    if not centered:
        mean = points.mean(dim=1).unsqueeze(1)
        diffs = (points - mean).reshape(B * N, D)
    else:
        diffs = points.reshape(B * N, D)
    prods = torch.bmm(diffs.unsqueeze(2), diffs.unsqueeze(1)).reshape(B, N, D, D)
    bcov = prods.sum(dim=1) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)


def inv_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the inverse of a tensor of shape :math:`(B, 2, 2)`, batched over
    the first dimension.

    If we denote

    .. math::

        M = \begin{pmatrix} a & b \\ c & d \end{pmatrix}

    the inverse is given by

    .. math::

        M^{-1} = \frac{1}{\det M} \mathrm{adj}(M) = \frac{1}{ad - bc}\begin{pmatrix}d & -b \\ -c & a\end{pmatrix}

    Arguments:
        M: a batch of 2x2 matrices to invert, i.e. a :math:`(B, 2, 2)` tensor
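
    Example (a minimal sketch, checked against ``torch.linalg.inv``):

        >>> A = torch.randn(5, 2, 2)
        >>> M = A @ A.transpose(1, 2) + torch.eye(2)  # SPD, hence invertible
        >>> torch.allclose(inv_2x2(M), torch.linalg.inv(M), atol=1e-5)
        True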
73 """
    det = torch.linalg.det(M).unsqueeze(-1).unsqueeze(-1)

    M_adj = M.clone()
    M_adj[:, 0, 0], M_adj[:, 1, 1] = M[:, 1, 1], M[:, 0, 0]
    M_adj[:, 0, 1] *= -1
    M_adj[:, 1, 0] *= -1
    M_inv = 1 / det * M_adj
    return M_inv


def sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the square root of a tensor of shape :math:`(B, 2, 2)`, batched
    over the first dimension.

    If we denote

    .. math::

        M = \begin{pmatrix} a & b \\ c & d \end{pmatrix}

    the square root is given by

    .. math::

        \begin{align}
        \sqrt{M} &= \frac{1}{t} ( M + \sqrt{\det M}\, I)\\
        t &= \sqrt{\mathrm{Tr}\, M + 2 \sqrt{\det M}}
        \end{align}

    Arguments:
        M: a batch of 2x2 matrices whose square roots to compute, i.e. a :math:`(B, 2, 2)` tensor
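
    Example (a minimal sketch on symmetric positive definite matrices, for
    which the closed form applies):

        >>> A = torch.randn(5, 2, 2)
        >>> M = A @ A.transpose(1, 2) + torch.eye(2)
        >>> S = sqrt_2x2(M)
        >>> torch.allclose(S @ S, M, atol=1e-4)
        True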
105 """
    N = M.shape[0]
    det = torch.linalg.det(M).unsqueeze(-1).unsqueeze(-1)
    sqrt_det = torch.sqrt(det)

    trace = torch.diagonal(M, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
    t = torch.sqrt(trace + 2 * sqrt_det)

    sqrt_M = 1 / t * (M + sqrt_det * torch.eye(2, device=M.device).tile(N, 1, 1))
    return sqrt_M


def slow_inv_sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    """
    Computes the square root of the inverse of a tensor of shape :math:`(B, 2, 2)`
    by composing :func:`inv_2x2` and :func:`sqrt_2x2`.

    Arguments:
        M: a batch of 2x2 matrices to sqrt-invert, i.e. a :math:`(B, 2, 2)` tensor
    """
    return sqrt_2x2(inv_2x2(M))


def inv_sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the square root of the inverse of a tensor of shape :math:`(B, 2, 2)`
    in a single pass, using the closed form

    .. math::

        \sqrt{M^{-1}} = \frac{1}{t}\left(\frac{\mathrm{adj}(M)}{\sqrt{\det M}} + I\right),
        \qquad t = \sqrt{\mathrm{Tr}\, M + 2 \sqrt{\det M}}

    which avoids materializing the intermediate inverse computed by
    :func:`slow_inv_sqrt_2x2`.

    Arguments:
        M: a batch of 2x2 matrices to sqrt-invert, i.e. a :math:`(B, 2, 2)` tensor
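
    Example (a minimal sketch, checked against the two-step variant):

        >>> A = torch.randn(5, 2, 2)
        >>> M = A @ A.transpose(1, 2) + torch.eye(2)
        >>> torch.allclose(inv_sqrt_2x2(M), slow_inv_sqrt_2x2(M), atol=1e-5)
        True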
133 """
    N = M.shape[0]
    det = torch.linalg.det(M).unsqueeze(-1).unsqueeze(-1)
    sqrt_det = torch.sqrt(det)

    trace = torch.diagonal(M, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
    t = torch.sqrt(trace + 2 * sqrt_det)

    # Adjugate of M: swap the diagonal entries and negate the off-diagonal ones
    M_adj = M.clone()
    M_adj[:, 0, 0], M_adj[:, 1, 1] = M[:, 1, 1], M[:, 0, 0]
    M_adj[:, 0, 1] *= -1
    M_adj[:, 1, 0] *= -1

    M_sqrt_inv = (
        1 / t * (M_adj / sqrt_det + torch.eye(2, device=M.device).tile(N, 1, 1))
    )
    return M_sqrt_inv


class _BatchNormNd(nn.Module):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d and BatchNorm2d, the only condition being that the input tensor
    is :math:`(batch\_size, features, d_1, d_2, ..)`.

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C, d_1, d_2, ..)`
        eps: a value added to the denominator for numerical stability. Default: :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics and initializes the statistics buffers `running_mean` and `running_var` as `None`. When these buffers are `None`, this module always uses batch statistics, in both training and eval modes. Default: `True`
        cdtype: the dtype for complex numbers. Default: `torch.complex64`
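
    Example (a minimal sketch with a complex valued input):

        >>> bn = _BatchNormNd(4)
        >>> z = torch.randn(16, 4, 8, dtype=torch.complex64)
        >>> bn(z).shape
        torch.Size([16, 4, 8])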
169 """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device: torch.device = None,
        cdtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((num_features, 2, 2), device=device)
            )
            self.bias = torch.nn.parameter.Parameter(
                torch.empty((num_features,), device=device, dtype=cdtype)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        if self.track_running_stats:
            # Register the running mean and running variance.
            # These will not be returned by model.parameters(), hence
            # not updated by the optimizer, although returned in the state_dict
            # and therefore stored as model's assets
            self.register_buffer(
                "running_mean",
                torch.zeros((num_features,), device=device, dtype=cdtype),
            )
            self.register_buffer(
                "running_var", torch.ones((num_features, 2, 2), device=device)
            )
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long, device=device)
            )
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            self.running_mean.zero_()
            # The running covariance is reset to the same 1/sqrt(2) diagonal
            # used below for the affine weight initialization
            self.running_var.zero_()
            self.running_var[:, 0, 0] = 1 / math.sqrt(2.0)
            self.running_var[:, 1, 1] = 1 / math.sqrt(2.0)

    def reset_parameters(self) -> None:
        with torch.no_grad():
            self.reset_running_stats()
            if self.affine:
                # Initialize all the weights to zeros
                init.zeros_(self.weight)
                # And then fill in the diagonal with 1/sqrt(2)
                # w is (C, 2, 2)
                self.weight[:, 0, 0] = 1 / math.sqrt(2.0)
                self.weight[:, 1, 1] = 1 / math.sqrt(2.0)
                # Initialize all the biases to zero
                init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z : [B, C, d1, d2, ..] (complex)
        # Computes gamma * V^{-1/2} (z - mu) + beta for every feature,
        # as in Trabelsi et al. (2018)
        batch_size = z.shape[0]
        dim1 = z.shape[1]
        other_dims = z.shape[2:]

        xc = z.transpose(0, 1).reshape(self.num_features, -1)  # num_features, BxHxW

        if self.training or not self.track_running_stats:
            # For training,
            # or for testing but using the batch stats for centering/scaling

            # Compute the means
            mus = xc.mean(dim=-1)  # num_features means

            # Center the xc
            xc_centered = xc - mus.unsqueeze(-1)  # num_features, BxHxW
            xc_centered = torch.view_as_real(xc_centered)  # num_features, BxHxW, 2

            # View the complex numbers as pairs of reals to compute the
            # variances and covariances
            covs = batch_cov(xc_centered, centered=True)  # num_features covariance matrices
        else:
            # The means come from the running stats
            mus = self.running_mean

            # Center the xc
            xc_centered = xc - mus.unsqueeze(-1)  # num_features, BxHxW
            xc_centered = torch.view_as_real(xc_centered)  # num_features, BxHxW, 2

            # The variance/covariance come from the running stats
            covs = self.running_var

        # Invert the covariance to scale
        invsqrt_covs = inv_sqrt_2x2(
            covs + self.eps * torch.eye(2, device=covs.device)
        )  # num_features, 2, 2
        # Note: the xc_centered.transpose is to take
        # xc_centered from (C, BxHxW, 2) to (C, 2, BxHxW)
        # so that the batch matrix multiply works as expected,
        # invsqrt_covs being (C, 2, 2)
        outz = torch.bmm(invsqrt_covs, xc_centered.transpose(1, 2))
        outz = outz.contiguous()  # num_features, 2, BxHxW

        if self.affine:
            # Scale by gamma; weight is (num_features, 2, 2), real valued
            outz = torch.bmm(self.weight, outz)  # num_features, 2, BxHxW
        outz = outz.transpose(1, 2).contiguous()
        outz = torch.view_as_complex(outz)  # num_features, BxHxW

        if self.affine:
            # Shift by beta; bias is (C, ), complex dtype
            outz += self.bias.view((self.num_features, 1))

        # Reshape back to the input layout, e.g. for a 2d input:
        # outz = outz.reshape(C, B, H, W).transpose(0, 1)
        outz = outz.reshape(dim1, batch_size, *other_dims).transpose(0, 1)

        if self.training and self.track_running_stats:
            # Update the running statistics; detach so that the buffers do not
            # keep the autograd graph of the current batch alive
            self.running_mean = (
                1.0 - self.momentum
            ) * self.running_mean + self.momentum * mus.detach()
            if torch.isnan(self.running_mean).any():
                raise RuntimeError("Running mean divergence")

            self.running_var = (
                1.0 - self.momentum
            ) * self.running_var + self.momentum * covs.detach()
            if torch.isnan(self.running_var).any():
                raise RuntimeError("Running var divergence")
        return outz


class BatchNorm1d(_BatchNormNd):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d and BatchNorm2d, the only condition being that the input tensor
    is :math:`(batch\_size, features, d_1, d_2, ..)`.

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C)` or :math:`(B, C, L)`
        eps: a value added to the denominator for numerical stability. Default: :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics and initializes the statistics buffers `running_mean` and `running_var` as `None`. When these buffers are `None`, this module always uses batch statistics, in both training and eval modes. Default: `True`
        cdtype: the dtype for complex numbers. Default: `torch.complex64`
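
    Example (a minimal sketch):

        >>> bn = BatchNorm1d(16)
        >>> z = torch.randn(8, 16, dtype=torch.complex64)
        >>> bn(z).shape
        torch.Size([8, 16])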
334 """
336 pass


class BatchNorm2d(_BatchNormNd):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d and BatchNorm2d, the only condition being that the input tensor
    is :math:`(batch\_size, features, d_1, d_2, ..)`.

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C, H, W)`
        eps: a value added to the denominator for numerical stability. Default: :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics and initializes the statistics buffers `running_mean` and `running_var` as `None`. When these buffers are `None`, this module always uses batch statistics, in both training and eval modes. Default: `True`
        cdtype: the dtype for complex numbers. Default: `torch.complex64`
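
    Example (a minimal sketch):

        >>> bn = BatchNorm2d(16)
        >>> z = torch.randn(8, 16, 32, 32, dtype=torch.complex64)
        >>> bn(z).shape
        torch.Size([8, 16, 32, 32])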
357 """
359 pass