Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/batchnorm.py: 83%

112 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24import math 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29import torch.nn.init as init 

30 

31 

def batch_cov(points: torch.Tensor, centered: bool = False) -> torch.Tensor:
    """
    Batched covariance computation
    Adapted from : https://stackoverflow.com/a/71357620/2164582

    For every batch element ``b``, computes the unbiased covariance matrix of
    the ``N`` vectors of dimension ``D`` stored in ``points[b]``, i.e. the sum
    of the outer products of the centered vectors divided by ``N - 1``.

    Arguments:
        points: the (B, N, D) input tensor from which to compute B covariances
        centered: If `True`, assumes for every batch, the N vectors are centered. default: `False`

    Returns:
        bcov: the covariances as a `(B, D, D)` tensor

    Note:
        With ``N == 1`` the unbiased estimate is undefined: the division by
        ``N - 1 == 0`` yields ``inf``/``nan`` entries.
    """
    # Unpacking enforces a 3D (B, N, D) input
    _, N, _ = points.shape
    if centered:
        diffs = points
    else:
        diffs = points - points.mean(dim=1, keepdim=True)
    # sum_n outer(diffs[b, n], diffs[b, n]) == diffs[b]^T @ diffs[b]:
    # one batched matmul instead of materializing B * N outer products
    bcov = diffs.transpose(1, 2) @ diffs / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)

53 

54 

def inv_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the inverse of a tensor of shape [N, 2, 2].

    If we denote

    .. math::

        M = \begin{pmatrix} a & b \\ c & d \end{pmatrix}

    The inverse is given by

    .. math::

        M^{-1} = \frac{1}{Det M} Adj(M) = \frac{1}{ad - bc}\begin{pmatrix}d & -b \\ -c & a\end{pmatrix}

    Arguments:
        M: a batch of 2x2 tensors to invert, i.e. a :math:`(B, 2, 2)` tensor.
           Any number of leading batch dimensions, :math:`(..., 2, 2)`, is accepted,
           including an unbatched :math:`(2, 2)` matrix.

    Returns:
        The inverses, a tensor with the same shape as `M`. Singular matrices
        (zero determinant) produce ``inf``/``nan`` entries.
    """
    # Determinant of every matrix, kept broadcastable against (..., 2, 2)
    det = torch.linalg.det(M)[..., None, None]

    a, b = M[..., 0, 0], M[..., 0, 1]
    c, d = M[..., 1, 0], M[..., 1, 1]
    # Adjugate built out of place (no in-place writes on a clone)
    M_adj = torch.stack(
        (torch.stack((d, -b), dim=-1), torch.stack((-c, a), dim=-1)), dim=-2
    )
    return M_adj / det

82 

83 

def sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the square root of the tensor of shape [N, 2, 2].

    If we denote

    .. math::

        M = \begin{pmatrix} a & b \\ c & d \end{pmatrix}

    The square root is given by :

    .. math::

        \begin{align}
        \sqrt{M} &= \frac{1}{t} ( M + \sqrt{Det M} I)\\
        t &= \sqrt{Tr M + 2 \sqrt{Det M}}
        \end{align}

    The formula requires :math:`Det M \geq 0` and :math:`Tr M + 2 \sqrt{Det M} > 0`,
    which holds for symmetric positive definite matrices such as covariance
    matrices; otherwise ``torch.sqrt`` produces ``nan`` entries.

    Arguments:
        M: a batch of 2x2 tensors to take the square root of, i.e. a :math:`(B, 2, 2)` tensor.
           Any number of leading batch dimensions, :math:`(..., 2, 2)`, is accepted.

    Returns:
        A tensor `S` with the same shape as `M` such that `S @ S == M`.
    """
    det = torch.linalg.det(M)[..., None, None]
    sqrt_det = torch.sqrt(det)

    trace = torch.diagonal(M, dim1=-2, dim2=-1).sum(-1)[..., None, None]
    t = torch.sqrt(trace + 2 * sqrt_det)

    # The identity broadcasts over every leading batch dimension and keeps M's dtype
    eye = torch.eye(2, dtype=M.dtype, device=M.device)
    sqrt_M = (M + sqrt_det * eye) / t
    return sqrt_M

115 

116 

def slow_inv_sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Reference (two-step) computation of the inverse square root of 2x2 matrices.

    The matrices are first inverted with `inv_2x2`, then the square root of
    the inverses is taken with `sqrt_2x2`. `inv_sqrt_2x2` computes the same
    quantity in a single closed-form expression.

    Arguments:
        M: a batch of 2x2 tensors to sqrt invert, i.e. a :math:`(B, 2, 2)` tensor

    Returns:
        :math:`\sqrt{M^{-1}}` for every matrix of the batch
    """
    inverses = inv_2x2(M)
    return sqrt_2x2(inverses)

125 

126 

def inv_sqrt_2x2(M: torch.Tensor) -> torch.Tensor:
    r"""
    Computes the square root of the inverse of a tensor of shape [N, 2, 2]

    Applying the 2x2 square root formula to :math:`M^{-1} = Adj(M) / Det M`
    and simplifying by :math:`\sqrt{Det M}` gives

    .. math::

        \begin{align}
        \sqrt{M^{-1}} &= \frac{1}{t} \left( \frac{Adj(M)}{\sqrt{Det M}} + I \right)\\
        t &= \sqrt{Tr M + 2 \sqrt{Det M}}
        \end{align}

    The formula requires :math:`Det M > 0` and :math:`Tr M + 2 \sqrt{Det M} > 0`,
    which holds for symmetric positive definite matrices such as covariance matrices.

    Arguments:
        M: a batch of 2x2 tensors to sqrt invert, i.e. a :math:`(B, 2, 2)` tensor.
           Any number of leading batch dimensions, :math:`(..., 2, 2)`, is accepted.

    Returns:
        A tensor `R` with the same shape as `M` such that `R @ R @ M == I`.
    """
    det = torch.linalg.det(M)[..., None, None]
    sqrt_det = torch.sqrt(det)

    trace = torch.diagonal(M, dim1=-2, dim2=-1).sum(-1)[..., None, None]
    t = torch.sqrt(trace + 2 * sqrt_det)

    # Adjugate of M, built out of place
    a, b = M[..., 0, 0], M[..., 0, 1]
    c, d = M[..., 1, 0], M[..., 1, 1]
    M_adj = torch.stack(
        (torch.stack((d, -b), dim=-1), torch.stack((-c, a), dim=-1)), dim=-2
    )

    # The identity broadcasts over every leading batch dimension and keeps M's dtype
    eye = torch.eye(2, dtype=M.dtype, device=M.device)
    M_sqrt_inv = (M_adj / sqrt_det + eye) / t
    return M_sqrt_inv

149 

150 

class _BatchNormNd(nn.Module):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d, BatchNorm2d, the only condition being the input tensor must be
    (batch_size, features, d1, d2, ..)

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`. Each complex value is handled as the real
    2D vector (real part, imaginary part): it is centered by the complex mean,
    whitened by the inverse square root of its :math:`2 \times 2` covariance
    matrix, then (if `affine`) multiplied by a learnable real :math:`2 \times 2`
    matrix and shifted by a learnable complex bias.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C, d_1, d_2, ...)`
        eps: a value added to the diagonal of the covariance matrices before their inversion, for numerical stability. Default :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Can be set to `None` for cumulative moving average (i.e. simple average). Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: `True`
        device: the device on which the parameters and buffers are allocated. Default: `None`
        cdtype: the dtype for complex numbers. Default torch.complex64. The real valued tensors (weight, running_var) use the matching real dtype (e.g. torch.float32 for torch.complex64).
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device: torch.device = None,
        cdtype: torch.dtype = torch.complex64,
    ) -> None:
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        # None selects a cumulative moving average of the batch statistics
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Real dtype paired with cdtype (float32 for complex64, float64 for
        # complex128), so that the real 2x2 matrices can be multiplied with
        # torch.view_as_real(input) without a dtype mismatch
        real_dtype = torch.empty(0, dtype=cdtype).real.dtype

        if self.affine:
            # One real 2x2 scaling matrix per feature
            self.weight = torch.nn.parameter.Parameter(
                torch.empty((num_features, 2, 2), device=device, dtype=real_dtype)
            )
            # One complex shift per feature
            self.bias = torch.nn.parameter.Parameter(
                torch.empty((num_features,), device=device, dtype=cdtype)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        if self.track_running_stats:
            # Register the running mean and running variance
            # These will not be returned by model.parameters(), hence
            # not updated by the optimizer although returned in the state_dict
            # and therefore stored as model's assets
            self.register_buffer(
                "running_mean",
                torch.zeros((num_features,), device=device, dtype=cdtype),
            )
            self.register_buffer(
                "running_var",
                torch.ones((num_features, 2, 2), device=device, dtype=real_dtype),
            )
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long, device=device)
            )
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer(
                "num_batches_tracked",
                None,
            )
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset the running statistics and the count of tracked batches.

        The running mean is zeroed, the running covariances are set to
        ``diag(1/sqrt(2), 1/sqrt(2))`` and ``num_batches_tracked`` to 0.
        Does nothing when ``track_running_stats`` is ``False``.
        """
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.zero_()
            self.running_var[:, 0, 0] = 1 / math.sqrt(2.0)
            self.running_var[:, 1, 1] = 1 / math.sqrt(2.0)
            self.num_batches_tracked.zero_()

    def reset_parameters(self) -> None:
        """Reset the running statistics and, if `affine`, the learnable parameters."""
        with torch.no_grad():
            self.reset_running_stats()
            if self.affine:
                # Initialize all the weights to zeros
                init.zeros_(self.weight)
                # And then fill in the diagonal with 1/sqrt(2)
                # w is C, 2, 2
                self.weight[:, 0, 0] = 1 / math.sqrt(2.0)
                self.weight[:, 1, 1] = 1 / math.sqrt(2.0)
                # Initialize all the biases to zero
                init.zeros_(self.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Normalize the complex input `z` of shape (B, C, d1, d2, ...).

        Batch statistics are used in training mode or when running statistics
        are not tracked; the running statistics are used otherwise. In training
        mode with tracked statistics, the running statistics are updated.

        Raises:
            ValueError: if `z` has fewer than 2 dimensions or its second
                dimension differs from `num_features`.
        """
        # z : [B, C, d1, d2, ..] (complex)
        if z.dim() < 2 or z.shape[1] != self.num_features:
            raise ValueError(
                f"Expected an input of shape (B, {self.num_features}, ...), "
                f"got {tuple(z.shape)}"
            )
        batch_size = z.shape[0]
        other_dims = z.shape[2:]

        # Gather, for every feature, all its B x d1 x d2 x ... samples
        xc = z.transpose(0, 1).reshape(self.num_features, -1)  # C, BxHxW

        # For training, or for testing but using the batch stats for centering/scaling
        use_batch_stats = self.training or not self.track_running_stats

        if use_batch_stats:
            mus = xc.mean(dim=-1)  # C complex means
        else:
            mus = self.running_mean

        # Center, then see the complex numbers as 2 reals to handle the
        # variances and covariances of the real and imaginary parts
        xc_centered = torch.view_as_real(xc - mus.unsqueeze(-1))  # C, BxHxW, 2

        if use_batch_stats:
            covs = batch_cov(xc_centered, centered=True)  # C covariance matrices (2x2)
        else:
            covs = self.running_var

        # Invert the (regularized) covariance to scale
        eye = torch.eye(2, device=covs.device, dtype=covs.dtype)
        invsqrt_covs = inv_sqrt_2x2(covs + self.eps * eye)  # C, 2, 2
        # xc_centered is transposed from (C, BxHxW, 2) to (C, 2, BxHxW) so that
        # the batched matrix product with the (C, 2, 2) matrices works per feature
        outz = torch.bmm(invsqrt_covs, xc_centered.transpose(1, 2))  # C, 2, BxHxW

        if self.affine:
            # Scale by gamma: weight is (C, 2, 2) real valued
            outz = torch.bmm(self.weight, outz)  # C, 2, BxHxW

        # Back to complex: view_as_complex needs a trailing contiguous dim of size 2
        outz = torch.view_as_complex(outz.transpose(1, 2).contiguous())  # C, BxHxW

        if self.affine:
            # Shift by beta: bias is (C, ) complex dtype
            outz = outz + self.bias.view((self.num_features, 1))

        # (C, B x d1 x ...) -> (C, B, d1, ...) -> (B, C, d1, ...)
        outz = outz.reshape(self.num_features, batch_size, *other_dims).transpose(0, 1)

        if self.training and self.track_running_stats:
            self._update_running_stats(mus, covs)
        return outz

    def _update_running_stats(self, mus: torch.Tensor, covs: torch.Tensor) -> None:
        """Blend the batch statistics into the running buffers.

        The update is done in place and outside autograd, so the buffers keep
        their identity and never hold a reference to the computation graph.

        Raises:
            RuntimeError: if a running statistic contains NaN after its update.
        """
        with torch.no_grad():
            self.num_batches_tracked.add_(1)
            if self.momentum is None:
                # Cumulative moving average: every batch gets the same weight
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum

            self.running_mean.mul_(1.0 - factor).add_(factor * mus)
            if torch.isnan(self.running_mean).any():
                raise RuntimeError("Running mean divergence")

            self.running_var.mul_(1.0 - factor).add_(factor * covs)
            if torch.isnan(self.running_var).any():
                raise RuntimeError("Running var divergence")

314 

315 

class BatchNorm1d(_BatchNormNd):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d, BatchNorm2d, the only condition being the input tensor must be
    (batch_size, features, d1, d2, ..)

    Intended for inputs of shape :math:`(B, C)` or :math:`(B, C, L)`.

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C)` or :math:`(B, C, L)`
        eps: a value added to the diagonal of the covariance matrices before their inversion, for numerical stability. Default :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Can be set to `None` for cumulative moving average (i.e. simple average). Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: `True`
        device: the device on which the parameters and buffers are allocated. Default: `None`
        cdtype: the dtype for complex numbers. Default torch.complex64
    """

337 

338 

class BatchNorm2d(_BatchNormNd):
    r"""
    BatchNorm for complex valued neural networks. The same code applies for
    BatchNorm1d, BatchNorm2d, the only condition being the input tensor must be
    (batch_size, features, d1, d2, ..)

    Intended for inputs of shape :math:`(B, C, H, W)`.

    The statistics will be computed over the :math:`batch\_size \times d_1 \times d_2 \times ..`
    vectors of size `features`.

    As defined by Trabelsi et al. (2018)

    Arguments:
        num_features: :math:`C` from an expected input of size :math:`(B, C, H, W)`
        eps: a value added to the diagonal of the covariance matrices before their inversion, for numerical stability. Default :math:`1e-5`.
        momentum: the value used for the running mean and running var computation. Can be set to `None` for cumulative moving average (i.e. simple average). Default: :math:`0.1`
        affine: a boolean value that when set to `True`, this module has learnable affine parameters. Default: `True`
        track_running_stats: a boolean value that when set to `True`, this module tracks the running mean and variance, and when set to `False`, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes. Default: `True`
        device: the device on which the parameters and buffers are allocated. Default: `None`
        cdtype: the dtype for complex numbers. Default torch.complex64
    """