Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/pooling.py: 66%

38 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Optional 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29from torch.nn.common_types import _size_2_t 

30 

31 

class MaxPool2d(nn.Module):
    r"""Applies a 2D max pooling on the modulus (magnitude) of the input signal

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`, output :math:`(N, C, H_{out}, W_{out})` and `kernel_size` `(kH, kW)` can be precisely described as:

    .. math::

        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{|input|}(N_i, C_j, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n)
        \end{aligned}

    For every window, the returned value is the complex input element whose
    modulus is the largest (the complex value itself, not its modulus).

    Internally, it is relying on the :external:py:class:`torch.nn.MaxPool2d`

    Arguments:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window (`None` defers to the default of :external:py:class:`torch.nn.MaxPool2d`)
        padding: implicit negative infinity padding to be added
        dilation: a parameter that controls the stride of elements in the window
        ceil_mode: when `True`, use `ceil` instead of `floor` to compute the output shape
        return_indices: if `True`, will return the max indices along with the outputs

    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        ceil_mode: bool = False,
        return_indices: bool = False,
    ) -> None:
        super().__init__()
        self.return_indices = return_indices
        # The wrapped pooling always returns indices: they are needed to pick
        # the complex values matching the maximal moduli in `forward`.
        self.m = nn.MaxPool2d(
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode=ceil_mode,
            return_indices=True,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        r"""
        Computes and returns the MaxPool over the magnitude of the input

        Arguments:
            z: complex input tensor, batched :math:`(N, C, H, W)` or unbatched
               :math:`(C, H, W)` as accepted by :external:py:class:`torch.nn.MaxPool2d`

        Returns:
            The pooled complex tensor or, when `return_indices` is `True`, a
            tuple `(pooled, indices)` where `indices` are those produced by
            :external:py:class:`torch.nn.MaxPool2d`, i.e. positions within the
            flattened :math:`H \times W` plane of each channel.
        """
        _, indices = self.m(torch.abs(z))

        # The indices address the flattened spatial plane (H * W) of each
        # (batch, channel) slice, not the fully flattened tensor. Gathering
        # along that last flattened dimension makes every channel select from
        # its own plane.
        flat_z = z.flatten(start_dim=-2)
        flat_indices = indices.flatten(start_dim=-2)
        pooled = flat_z.gather(-1, flat_indices).reshape(indices.shape)

        if self.return_indices:
            return pooled, indices
        return pooled

87 

88 

class AvgPool2d(nn.Module):
    """
    Implementation of torch.nn.AvgPool2d for complex numbers.
    Apply AvgPool2d on the real and imaginary part.
    Returns complex values associated to the AvgPool2d results.

    Internally, the complex input is viewed as a real tensor with a trailing
    dimension of size 2 (real, imaginary) and pooled with
    :external:py:class:`torch.nn.AvgPool3d`, using a window of size 1, a stride
    of 1 and no padding along that trailing dimension, so that real and
    imaginary parts are averaged independently.

    Arguments:
        kernel_size: the size of the window to compute the average
        stride: the stride of the window. If `None`, it defaults to `kernel_size`
        padding: implicit zero padding to be added
        ceil_mode: when `True`, use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when `True`, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.kernel_size = self._to_3d(kernel_size, 1)
        # `None` is forwarded as-is: AvgPool3d then uses the kernel size,
        # i.e. (kH, kW, 1), which leaves the real/imaginary axis untouched.
        self.stride = None if stride is None else self._to_3d(stride, 1)
        self.padding = self._to_3d(padding, 0)

        self.m = torch.nn.AvgPool3d(
            self.kernel_size,
            self.stride,
            self.padding,
            ceil_mode=ceil_mode,
            count_include_pad=count_include_pad,
            divisor_override=divisor_override,
        )

    @staticmethod
    def _to_3d(value, last: int):
        """Expand a 2D pooling argument to the 3D form used on the real view.

        Arguments:
            value: an int, or a sequence of 1 or 2 ints, describing the spatial
                   (H, W) setting. A sequence of 3 or more elements is returned
                   unchanged for backward compatibility.
            last: the value used for the trailing real/imaginary dimension
                  (1 for kernel size and stride, 0 for padding)

        Returns:
            A list `[h, w, last]`, or `value` unchanged when it already has
            at least 3 elements.
        """
        if isinstance(value, int):
            return [value, value, last]
        values = list(value)
        if len(values) == 1:
            # A single-element sequence applies to both spatial dimensions.
            return values * 2 + [last]
        if len(values) == 2:
            return values + [last]
        return value

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the average over the real and imaginary parts.

        Arguments:
            z: complex input tensor; presumably :math:`(N, C, H, W)` or
               :math:`(C, H, W)` so that its real view matches the batched or
               unbatched input accepted by AvgPool3d — confirm against callers

        Returns:
            The complex tensor whose real and imaginary parts are the averages
            of the input's real and imaginary parts.
        """
        return torch.view_as_complex(self.m(torch.view_as_real(z)))