# Coverage report residue (coverage.py v7.8.0, created 2025-04-13): 66% of 38 statements covered
# MIT License

# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Standard imports
from typing import Optional

# External imports
import torch
import torch.nn as nn
from torch.nn.common_types import _size_2_t
class MaxPool2d(nn.Module):
    r"""Applies a 2D max pooling on the magnitude of the input signal.

    The pooling window location is selected by the maximum of the modulus of
    the complex input; the returned values are the original complex entries
    at those locations.

    In the simplest case, the output value of the layer with input size
    :math:`(N, C, H, W)`, output :math:`(N, C, H_{out}, W_{out})` and
    `kernel_size` `(kH, kW)` can be precisely described as:

    .. math::

        \begin{aligned}
        out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                & \text{|input|}(N_i, C_j, \text{stride[0]} \times h + m,
                                  \text{stride[1]} \times w + n)
        \end{aligned}

    Internally, it is relying on the :external:py:class:`torch.nn.MaxPool2d`

    Arguments:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window
        padding: implicit negative infinity padding to be added
        dilation: a parameter that controls the stride of elements in the window
        ceil_mode: when `True`, use `ceil` instead of `floor` to compute the output shape
        return_indices: if `True`, will return the max indices along with the outputs
    """

    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        ceil_mode: bool = False,
        return_indices: bool = False,
    ) -> None:
        super().__init__()
        self.return_indices = return_indices
        # Always request the indices internally: they are required to gather
        # the complex values at the argmax locations of the magnitude.
        self.m = nn.MaxPool2d(
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode=ceil_mode,
            return_indices=True,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes and returns the complex values whose magnitude is maximal
        within each pooling window.

        Arguments:
            z: complex tensor of shape :math:`(N, C, H, W)` or :math:`(C, H, W)`

        Returns:
            The pooled complex tensor (and the pooling indices as a second
            value when `return_indices` is `True`).
        """
        _, indices = self.m(torch.abs(z))
        # The indices returned by nn.MaxPool2d are local to each
        # (batch, channel) spatial plane, i.e. values in [0, H*W). We must
        # therefore gather within the flattened spatial dimensions of each
        # plane. Indexing the fully flattened tensor (z.flatten()[indices])
        # would read from the first plane only and is wrong for N*C > 1.
        pooled = (
            z.flatten(start_dim=-2)
            .gather(dim=-1, index=indices.flatten(start_dim=-2))
            .view_as(indices)
        )
        if self.return_indices:
            return pooled, indices
        else:
            return pooled
class AvgPool2d(nn.Module):
    """
    Implementation of torch.nn.AvgPool2d for complex numbers.
    Applies AvgPool2d independently on the real and imaginary parts.
    Returns the complex values associated to the AvgPool2d results.

    Internally, the complex tensor is viewed as a real tensor with a trailing
    dimension of size 2 (real/imaginary) and processed with a 3D average
    pooling whose depth kernel is 1, so the two parts never mix.

    Arguments:
        kernel_size: the size of the window to compute the average
        stride: the stride of the window; defaults to `kernel_size` when `None`
        padding: implicit zero padding to be added
        ceil_mode: when `True`, use `ceil` instead of `floor` to compute the output shape
        count_include_pad: when `True`, will include the zero-padding in the averaging calculation
        divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.
    """

    @staticmethod
    def _to_3d(value, fill):
        """Extends a 2D pooling parameter to 3D, with `fill` on the extra
        trailing (real/imaginary) dimension. `None` is passed through
        unchanged so AvgPool3d can apply its own default."""
        if value is None:
            return None
        if isinstance(value, int):
            return [value] * 2 + [fill]
        # A 2-sequence is extended with `fill`; a 3-sequence is kept as-is.
        value = list(value)
        return value + [fill] if len(value) < 3 else value

    def __init__(
        self,
        kernel_size: _size_2_t,
        stride: Optional[_size_2_t] = None,
        padding: _size_2_t = 0,
        ceil_mode: bool = False,
        count_include_pad: bool = True,
        divisor_override: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Extend every 2D parameter to 3D: the extra trailing dimension holds
        # the (real, imag) pair and must not be pooled over (kernel 1, pad 0).
        # Note: the original code crashed with AttributeError for the default
        # stride=None and built an invalid nested list for 2-tuples.
        self.kernel_size = self._to_3d(kernel_size, 1)
        # stride=None is forwarded so AvgPool3d defaults it to kernel_size.
        self.stride = self._to_3d(stride, 1)
        self.padding = self._to_3d(padding, 0)

        self.m = torch.nn.AvgPool3d(
            self.kernel_size,
            self.stride,
            self.padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the average over the real and imaginary parts separately
        and returns the corresponding complex tensor.
        """
        return torch.view_as_complex(self.m(torch.view_as_real(z)))