Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Replace interpolate with resize #731

Merged
merged 3 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmseg/models/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.ops import resize
from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, swin_convert
Expand Down Expand Up @@ -746,7 +747,7 @@ def init_weights(self):
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = F.interpolate(
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/backbones/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.ops import Upsample
from ..builder import BACKBONES
from ..utils import UpConvBlock

Expand Down Expand Up @@ -203,7 +204,7 @@ def __init__(self,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
upsample = nn.Upsample(**upsample_cfg)
upsample = Upsample(**upsample_cfg)
if conv_first:
self.interp_upsample = nn.Sequential(conv, upsample)
else:
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
normal_init, trunc_normal_init)
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.ops import resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed, vit_convert
Expand Down Expand Up @@ -373,7 +373,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
pos_embed_weight = pos_embed_weight.reshape(
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate(
pos_embed_weight = resize(
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/decode_heads/fpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from mmseg.ops import Upsample, resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, feature_strides, **kwargs):
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
nn.Upsample(
Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/setr_mla_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(self, mla_channels=128, up_scale=4, **kwargs):
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/setr_up_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.ops import Upsample
from ..builder import HEADS
from .decode_head import BaseDecodeHead

Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self,
padding=int(kernel_size - 1) // 2,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Upsample(
Upsample(
scale_factor=up_scale,
mode='bilinear',
align_corners=self.align_corners)))
Expand Down
6 changes: 3 additions & 3 deletions mmseg/models/necks/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16

from mmseg.ops import resize
from ..builder import NECKS


Expand Down Expand Up @@ -173,11 +174,10 @@ def forward(self, inputs):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
laterals[i - 1] += resize(laterals[i], **self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(
laterals[i - 1] += resize(
laterals[i], size=prev_shape, **self.upsample_cfg)

# build outputs
Expand Down
4 changes: 2 additions & 2 deletions mmseg/models/necks/multilevel_neck.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init

from mmseg.ops import resize
from ..builder import NECKS


Expand Down Expand Up @@ -70,7 +70,7 @@ def forward(self, inputs):
inputs = [inputs[0] for _ in range(self.num_outs)]
outs = []
for i in range(self.num_outs):
x_resize = F.interpolate(
x_resize = resize(
inputs[i], scale_factor=self.scales[i], mode='bilinear')
outs.append(self.convs[i](x_resize))
return tuple(outs)
10 changes: 5 additions & 5 deletions tests/test_models/test_backbones/test_unet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
import torch
from mmcv.cnn import ConvModule
from torch import nn

from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from .utils import check_norm_state


Expand Down Expand Up @@ -145,7 +145,7 @@ def test_interp_conv():
block = InterpConv(64, 32, conv_first=False)
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])

Expand All @@ -154,7 +154,7 @@ def test_interp_conv():
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], ConvModule)
assert isinstance(block.interp_upsample[1], nn.Upsample)
assert isinstance(block.interp_upsample[1], Upsample)
assert x_out.shape == torch.Size([1, 32, 256, 256])

# test InterpConv with bilinear upsample for upsample 2X.
Expand All @@ -166,7 +166,7 @@ def test_interp_conv():
scale_factor=2, mode='bilinear', align_corners=False))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'bilinear'
Expand All @@ -179,7 +179,7 @@ def test_interp_conv():
upsample_cfg=dict(scale_factor=2, mode='nearest'))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[0], Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'nearest'
Expand Down
5 changes: 3 additions & 2 deletions tools/deploy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize


class ONNXRuntimeSegmentor(BaseSegmentor):
Expand Down Expand Up @@ -79,7 +80,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
Expand Down Expand Up @@ -127,7 +128,7 @@ def simple_test(self, img: torch.Tensor, img_meta: Iterable,
if not (ori_shape[0] == seg_pred.shape[-2]
and ori_shape[1] == seg_pred.shape[-1]):
seg_pred = torch.from_numpy(seg_pred).float()
seg_pred = torch.nn.functional.interpolate(
seg_pred = resize(
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
seg_pred = seg_pred.long().detach().cpu().numpy()
seg_pred = seg_pred[0]
Expand Down
6 changes: 2 additions & 4 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

torch.manual_seed(3)

Expand Down Expand Up @@ -210,10 +211,7 @@ def pytorch2onnx(model,

if dynamic_export and test_mode == 'whole':
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
img_list = [resize(_, scale_factor=1.5) for _ in img_list]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
Expand Down