diff --git a/src/resnet.py b/src/resnet.py index 18641f0..33ce5ea 100644 --- a/src/resnet.py +++ b/src/resnet.py @@ -10,6 +10,8 @@ from torch import nn import torch.nn.functional as F +import torch +from typing import List, Type, Optional, Tuple class BasicBlock(nn.Module): @@ -18,7 +20,7 @@ class BasicBlock(nn.Module): """ expansion = 1 - def __init__(self, in_planes, planes, stride=1): + def __init__(self, in_planes:int, planes:int, stride:int=1) -> None: super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) @@ -36,7 +38,7 @@ def __init__(self, in_planes, planes, stride=1): nn.BatchNorm2d(self.expansion * planes), ) - def forward(self, x): + def forward(self, x:torch.Tensor) -> torch.Tensor: """ Forward pass for the BasicBlock. """ @@ -52,7 +54,7 @@ class Bottleneck(nn.Module): """ expansion = 4 - def __init__(self, in_planes, planes, stride=1): + def __init__(self, in_planes:int, planes:int, stride:int=1): super().__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) @@ -70,7 +72,7 @@ def __init__(self, in_planes, planes, stride=1): nn.BatchNorm2d(self.expansion * planes), ) - def forward(self, x): + def forward(self, x:torch.Tensor) -> torch.Tensor: """ Forward pass for the Bottleneck block. """ @@ -86,7 +88,7 @@ class ResNet(nn.Module): """ A ResNet model. """ - def __init__(self, block, num_blocks, num_classes=10, num_channels=3): + def __init__(self, block:Type[nn.Module], num_blocks:List[int], num_classes:int=10, num_channels:int=3)-> None: super().__init__() self.in_planes = 64 self.conv1 = nn.Conv2d( @@ -99,7 +101,7 @@ def __init__(self, block, num_blocks, num_classes=10, num_channels=3): self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512 * block.expansion, num_classes) - def _make_layer(self, block, planes, num_blocks, stride): + def _make_layer(self, block:Type[nn.Module], planes:int, num_blocks:int, stride:int)-> nn.Sequential: strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: @@ -107,7 +109,7 @@ def _make_layer(self, block, planes, num_blocks, stride): self.in_planes = planes * block.expansion return nn.Sequential(*layers) - def forward(self, x, position=0, out_feature=False): + def forward(self, x:torch.Tensor, position:int=0, out_feature:bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Forward pass for the ResNet model. """ @@ -139,37 +141,37 @@ def forward(self, x, position=0, out_feature=False): return x, feature return x -def resnet10(num_channels=3, num_classes=10): +def resnet10(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-10 model. """ return ResNet(BasicBlock, [1, 1, 1, 1], num_classes, num_channels) -def resnet18(num_channels=3, num_classes=10): +def resnet18(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-18 model. """ return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, num_channels) -def resnet34(num_channels=3, num_classes=10): +def resnet34(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-34 model. """ return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, num_channels) -def resnet50(num_channels=3, num_classes=10): +def resnet50(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-50 model. """ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, num_channels) -def resnet101(num_channels=3, num_classes=10): +def resnet101(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-101 model. """ return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, num_channels) -def resnet152(num_channels=3, num_classes=10): +def resnet152(num_channels:int=3, num_classes:int=10) -> ResNet: """ Constructs a ResNet-152 model. """ diff --git a/src/resnet_in.py b/src/resnet_in.py index dca143b..f406438 100644 --- a/src/resnet_in.py +++ b/src/resnet_in.py @@ -8,6 +8,8 @@ from typing import Any from torch import nn from torch.hub import load_state_dict_from_url +import torch +from typing import List, Type, Optional __all__ = [ @@ -37,7 +39,7 @@ } -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): +def conv3x3(in_planes:int , out_planes:int, stride:int=1, groups:int=1, dilation:int=1) -> nn.Conv2d: """3x3 convolution with padding""" return nn.Conv2d( in_planes, @@ -51,7 +53,7 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): ) -def conv1x1(in_planes, out_planes, stride=1): +def conv1x1(in_planes:int, out_planes:int, stride:int=1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) @@ -63,8 +65,8 @@ class BasicBlock(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__(self, inplanes:int, planes:int, stride:int=1, downsample:Optional[nn.Sequential]=None, groups:int=1, + base_width:int=64, dilation:int=1, norm_layer:Optional[Type[nn.Module]]=None) -> None: super(BasicBlock, self).__init__() if norm_layer is None: @@ -82,7 +84,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x:torch.Tensor)-> torch.Tensor: """ Forward pass for the BasicBlock. """ @@ -106,8 +108,8 @@ class Bottleneck(nn.Module): expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, - base_width=64, dilation=1, norm_layer=None): + def __init__(self, inplanes:int, planes:int, stride:int=1, downsample:Optional[nn.Sequential]=None, groups:int=1, + base_width:int=64, dilation:int=1, norm_layer: Optional[Type[nn.Module]]=None): super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -123,7 +125,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, self.downsample = downsample self.stride = stride - def forward(self, x): + def forward(self, x:torch.Tensor)-> torch.Tensor: """ Forward pass for the Bottleneck block. """ @@ -148,9 +150,9 @@ class ResNet(nn.Module): """ A ResNet model. """ - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, - groups=1, width_per_group=64, replace_stride_with_dilation=None, - norm_layer=None): + def __init__(self, block:Type[nn.Module], layers: List[int], num_classes:int=1000, zero_init_residual:bool=False, + groups:int=1, width_per_group:int=64, replace_stride_with_dilation:Optional[List[bool]]=None, + norm_layer:Optional[Type[nn.Module]]=None) -> None: super(ResNet, self).__init__() if norm_layer is None: @@ -206,11 +208,11 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) - def reset_fc(self, num_classes): + def reset_fc(self, num_classes:int) -> None: """Resets the fully connected layer with the specified number of classes.""" self.fc = nn.Linear(512 * self.block.expansion, num_classes) - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + def _make_layer(self, block:Type[nn.Module], planes:int, blocks:int, stride:int=1, dilate:bool=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -244,7 +246,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) - def _forward_impl(self, x, position, return_features): + def _forward_impl(self, x:torch.Tensor, position:int, return_features:bool) -> torch.Tensor: if position == 0: x = self.conv1(x) x = self.bn1(x) @@ -267,20 +269,20 @@ def _forward_impl(self, x, position, return_features): return x, feat return x - def forward(self, x, position, return_features=False): + def forward(self, x:torch.Tensor, position:int, return_features:bool=False) -> torch.Tensor: """Forward pass for the ResNet model.""" return self._forward_impl(x, position, return_features=return_features) def _resnet( - arch, - block, - layers, - pretrained, - progress, - num_channels=3, - num_classes=1000, - **kwargs + arch:str, + block:Type[nn.Module], + layers:List[int], + pretrained:bool, + progress:bool, + num_channels:int=3, + num_classes:int=1000, + **kwargs: Any ): if num_channels != 3: raise ValueError("Only 3 channels supported for now") @@ -298,7 +300,7 @@ def _resnet( return model -def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any): +def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -310,7 +312,7 @@ def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any): -def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any): +def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any) -> ResNet: r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -322,7 +324,7 @@ def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any): -def resnet50(pretrained=False, progress=True, **kwargs): +def resnet50(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -334,7 +336,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): -def resnet101(pretrained=False, progress=True, **kwargs): +def resnet101(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -346,7 +348,7 @@ def resnet101(pretrained=False, progress=True, **kwargs): -def resnet152(pretrained=False, progress=True, **kwargs): +def resnet152(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -358,7 +360,7 @@ def resnet152(pretrained=False, progress=True, **kwargs): -def resnext50_32x4d(pretrained=False, progress=True, **kwargs): +def resnext50_32x4d(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: @@ -372,7 +374,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs): -def resnext101_32x8d(pretrained=False, progress=True, **kwargs): +def resnext101_32x8d(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: @@ -386,7 +388,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): -def wide_resnet50_2(pretrained=False, progress=True, **kwargs): +def wide_resnet50_2(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ The model is the same as ResNet except for the bottleneck number of channels @@ -403,7 +405,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs): -def wide_resnet101_2(pretrained=False, progress=True, **kwargs): +def wide_resnet101_2(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet: r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_ The model is the same as ResNet except for the bottleneck number of channels