Coverage for /home/runner/work/torchcvnn/torchcvnn/src/torchcvnn/nn/modules/dropout.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-13 08:53 +0000

1# MIT License 

2 

3# Copyright (c) 2023 Jérémie Levi, Victor Dhédin, Jeremy Fix 

4 

5# Permission is hereby granted, free of charge, to any person obtaining a copy 

6# of this software and associated documentation files (the "Software"), to deal 

7# in the Software without restriction, including without limitation the rights 

8# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

9# copies of the Software, and to permit persons to whom the Software is 

10# furnished to do so, subject to the following conditions: 

11 

12# The above copyright notice and this permission notice shall be included in 

13# all copies or substantial portions of the Software. 

14 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23# Standard imports 

24from typing import Optional 

25 

26# External imports 

27import torch 

28import torch.nn as nn 

29from torch.nn.common_types import _size_2_t 

30 

31 

class Dropout(nn.Module):
    r"""
    Applies dropout to zero out values of the inputs

    During training, randomly zeroes some of the elements of the input tensor
    with probability :math:`p` using samples from a Bernoulli distribution. Each
    element is zeroed out independently on every forward call; for a complex
    input, the real and imaginary parts of an element are dropped together
    since both are multiplied by the same real-valued mask entry.

    Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
    training. This means that during evaluation the module simply computes an
    identity function.

    Note:
        As for now, with torch 2.1.0 :external:py:class:`torch.nn.Dropout` cannot be applied on
        complex valued tensors. However, our implementation relies on the torch
        implementation by dropping out a tensor of ones, which is then used as a
        multiplicative mask on the input.

    Arguments:
        p: probability of an element to be zeroed. Must be in :math:`[0, 1]`.

    Raises:
        ValueError: if ``p`` is not in :math:`[0, 1]`.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        # Fail at construction time, like torch.nn.Dropout, rather than at the
        # first forward call.
        if p < 0.0 or p > 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Multiply ``z`` by a dropout mask of ones (identity in eval mode).

        Arguments:
            z: input tensor, real or complex valued.

        Returns:
            A tensor with the same shape as ``z``.
        """
        # Build the mask directly on z's device, which avoids a host allocation
        # and a host-to-device copy on every call. Use z's real precision so
        # that, e.g., complex128 inputs get a float64 mask and an exact
        # 1/(1-p) scaling instead of a float32-rounded one.
        if z.is_complex():
            mask_dtype = z.real.dtype
        elif z.is_floating_point():
            mask_dtype = z.dtype
        else:
            # Non floating point inputs: keep torch's default dtype, as before.
            mask_dtype = None
        ones = torch.ones(z.shape, dtype=mask_dtype, device=z.device)
        mask = torch.nn.functional.dropout(ones, self.p, training=self.training)
        return mask * z

63 

64 

class Dropout2d(nn.Module):
    r"""
    Applies dropout to zero out complete channels

    During training, randomly zeroes entire channels of the input tensor with
    probability :math:`p`, as done by :py:func:`torch.nn.functional.dropout2d`
    (typically on inputs of shape :math:`(N, C, H, W)`). The kept channels are
    scaled by a factor of :math:`\frac{1}{1-p}`, so that during evaluation the
    module simply computes an identity function.

    Note:
        As for now, with torch 2.1.0 :external:py:class:`torch.nn.Dropout2d` cannot be applied on
        complex valued tensors. However, our implementation relies on the torch
        implementation by dropping out a tensor of ones, which is then used as a
        multiplicative mask on the input.

    Arguments:
        p: probability of a channel to be zeroed. Must be in :math:`[0, 1]`.

    Raises:
        ValueError: if ``p`` is not in :math:`[0, 1]`.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        # Fail at construction time, like torch.nn.Dropout2d, rather than at
        # the first forward call.
        if p < 0.0 or p > 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Multiply ``z`` by a channel-wise dropout mask (identity in eval mode).

        Arguments:
            z: input tensor, real or complex valued, with a shape accepted by
                :py:func:`torch.nn.functional.dropout2d`.

        Returns:
            A tensor with the same shape as ``z``.
        """
        # Build the mask directly on z's device, which avoids a host allocation
        # and a host-to-device copy on every call. Use z's real precision so
        # that, e.g., complex128 inputs get a float64 mask and an exact
        # 1/(1-p) scaling instead of a float32-rounded one.
        if z.is_complex():
            mask_dtype = z.real.dtype
        elif z.is_floating_point():
            mask_dtype = z.dtype
        else:
            # Non floating point inputs: keep torch's default dtype, as before.
            mask_dtype = None
        ones = torch.ones(z.shape, dtype=mask_dtype, device=z.device)
        mask = torch.nn.functional.dropout2d(ones, self.p, training=self.training)
        return mask * z