Skip to content

Commit

Permalink
SONAR Onboarding (#73)
Browse files Browse the repository at this point in the history
* Fixed type errors and implemented type hinting in resnet_in

* Fixing type errors and some type hinting in resnet

---------

Co-authored-by: Samuel Tesfai <samueltesfai@dhcp-10-29-98-194.dyn.mit.edu>
  • Loading branch information
samuelt0 and Samuel Tesfai authored Sep 9, 2024
1 parent cd8531e commit af9030d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 45 deletions.
28 changes: 15 additions & 13 deletions src/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from torch import nn
import torch.nn.functional as F
import torch
from typing import List, Type, Optional, Tuple


class BasicBlock(nn.Module):
Expand All @@ -18,7 +20,7 @@ class BasicBlock(nn.Module):
"""
expansion = 1

def __init__(self, in_planes, planes, stride=1):
def __init__(self, in_planes:int, planes:int, stride:int=1) -> None:

super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
Expand All @@ -36,7 +38,7 @@ def __init__(self, in_planes, planes, stride=1):
nn.BatchNorm2d(self.expansion * planes),
)

def forward(self, x):
def forward(self, x:torch.Tensor) -> torch.Tensor:
"""
Forward pass for the BasicBlock.
"""
Expand All @@ -52,7 +54,7 @@ class Bottleneck(nn.Module):
"""
expansion = 4

def __init__(self, in_planes, planes, stride=1):
def __init__(self, in_planes:int, planes:int, stride:int=1):
super().__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
Expand All @@ -70,7 +72,7 @@ def __init__(self, in_planes, planes, stride=1):
nn.BatchNorm2d(self.expansion * planes),
)

def forward(self, x):
def forward(self, x:torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Bottleneck block.
"""
Expand All @@ -86,7 +88,7 @@ class ResNet(nn.Module):
"""
A ResNet model.
"""
def __init__(self, block, num_blocks, num_classes=10, num_channels=3):
def __init__(self, block:Type[nn.Module], num_blocks:List[int], num_classes:int=10, num_channels:int=3)-> None:
super().__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(
Expand All @@ -99,15 +101,15 @@ def __init__(self, block, num_blocks, num_classes=10, num_channels=3):
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512 * block.expansion, num_classes)

def _make_layer(self, block, planes, num_blocks, stride):
def _make_layer(self, block:Type[nn.Module], planes:int, num_blocks:int, stride:int)-> nn.Sequential:
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x, position=0, out_feature=False):
def forward(self, x:torch.Tensor, position:int=0, out_feature:bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass for the ResNet model.
"""
Expand Down Expand Up @@ -139,37 +141,37 @@ def forward(self, x, position=0, out_feature=False):
return x, feature
return x

def resnet10(num_channels=3, num_classes=10):
def resnet10(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-10 model.
"""
return ResNet(BasicBlock, [1, 1, 1, 1], num_classes, num_channels)

def resnet18(num_channels=3, num_classes=10):
def resnet18(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-18 model.
"""
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, num_channels)

def resnet34(num_channels=3, num_classes=10):
def resnet34(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-34 model.
"""
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes, num_channels)

def resnet50(num_channels=3, num_classes=10):
def resnet50(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-50 model.
"""
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, num_channels)

def resnet101(num_channels=3, num_classes=10):
def resnet101(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-101 model.
"""
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes, num_channels)

def resnet152(num_channels=3, num_classes=10):
def resnet152(num_channels:int=3, num_classes:int=10) -> ResNet:
"""
Constructs a ResNet-152 model.
"""
Expand Down
66 changes: 34 additions & 32 deletions src/resnet_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Any
from torch import nn
from torch.hub import load_state_dict_from_url
import torch
from typing import List, Type, Optional


__all__ = [
Expand Down Expand Up @@ -37,7 +39,7 @@
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
def conv3x3(in_planes:int , out_planes:int, stride:int=1, groups:int=1, dilation:int=1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
Expand All @@ -51,7 +53,7 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
)


def conv1x1(in_planes, out_planes, stride=1):
def conv1x1(in_planes:int, out_planes:int, stride:int=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

Expand All @@ -63,8 +65,8 @@ class BasicBlock(nn.Module):
expansion = 1


def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
def __init__(self, inplanes:int, planes:int, stride:int=1, downsample:Optional[nn.Sequential]=None, groups:int=1,
base_width:int=64, dilation:int=1, norm_layer:Optional[Type[nn.Module]]=None) -> None:
super(BasicBlock, self).__init__()

if norm_layer is None:
Expand All @@ -82,7 +84,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
self.downsample = downsample
self.stride = stride

def forward(self, x):
def forward(self, x:torch.Tensor)-> torch.Tensor:
"""
Forward pass for the BasicBlock.
"""
Expand All @@ -106,8 +108,8 @@ class Bottleneck(nn.Module):
expansion = 4


def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
def __init__(self, inplanes:int, planes:int, stride:int=1, downsample:Optional[nn.Sequential]=None, groups:int=1,
base_width:int=64, dilation:int=1, norm_layer: Optional[Type[nn.Module]]=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
Expand All @@ -123,7 +125,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
self.downsample = downsample
self.stride = stride

def forward(self, x):
def forward(self, x:torch.Tensor)-> torch.Tensor:
"""
Forward pass for the Bottleneck block.
"""
Expand All @@ -148,9 +150,9 @@ class ResNet(nn.Module):
"""
A ResNet model.
"""
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
def __init__(self, block:Type[nn.Module], layers: List[int], num_classes:int=1000, zero_init_residual:bool=False,
groups:int=1, width_per_group:int=64, replace_stride_with_dilation:Optional[List[bool]]=None,
norm_layer:Optional[Type[nn.Module]]=None) -> None:
super(ResNet, self).__init__()

if norm_layer is None:
Expand Down Expand Up @@ -206,11 +208,11 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)

def reset_fc(self, num_classes):
def reset_fc(self, num_classes:int) -> None:
"""Resets the fully connected layer with the specified number of classes."""
self.fc = nn.Linear(512 * self.block.expansion, num_classes)

def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
def _make_layer(self, block:Type[nn.Module], planes:int, blocks:int, stride:int=1, dilate:bool=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
Expand Down Expand Up @@ -244,7 +246,7 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):

return nn.Sequential(*layers)

def _forward_impl(self, x, position, return_features):
def _forward_impl(self, x:torch.Tensor, position:int, return_features:bool) -> torch.Tensor:
if position == 0:
x = self.conv1(x)
x = self.bn1(x)
Expand All @@ -267,20 +269,20 @@ def _forward_impl(self, x, position, return_features):
return x, feat
return x

def forward(self, x, position, return_features=False):
def forward(self, x:torch.Tensor, position:int, return_features:bool=False) -> torch.Tensor:
"""Forward pass for the ResNet model."""
return self._forward_impl(x, position, return_features=return_features)


def _resnet(
arch,
block,
layers,
pretrained,
progress,
num_channels=3,
num_classes=1000,
**kwargs
arch:str,
block:Type[nn.Module],
layers:List[int],
pretrained:bool,
progress:bool,
num_channels:int=3,
num_classes:int=1000,
**kwargs: Any
):
if num_channels != 3:
raise ValueError("Only 3 channels supported for now")
Expand All @@ -298,7 +300,7 @@ def _resnet(
return model


def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any):
def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -310,7 +312,7 @@ def resnet18(pretrained:bool=False, progress:bool=True, **kwargs: Any):



def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any):
def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -322,7 +324,7 @@ def resnet34(pretrained:bool=False, progress: bool=True, **kwargs: Any):



def resnet50(pretrained=False, progress=True, **kwargs):
def resnet50(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -334,7 +336,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):



def resnet101(pretrained=False, progress=True, **kwargs):
def resnet101(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -346,7 +348,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):



def resnet152(pretrained=False, progress=True, **kwargs):
def resnet152(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -358,7 +360,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):



def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
def resnext50_32x4d(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
Expand All @@ -372,7 +374,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):



def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
def resnext101_32x8d(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
Expand All @@ -386,7 +388,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):



def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
def wide_resnet50_2(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
Expand All @@ -403,7 +405,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):



def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def wide_resnet101_2(pretrained:bool=False, progress:bool=True, **kwargs:Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
Expand Down

0 comments on commit af9030d

Please sign in to comment.