diff --git a/mmedit/models/editors/rdn/rdn_net.py b/mmedit/models/editors/rdn/rdn_net.py
index 29fa19ce70..24eeb912f8 100644
--- a/mmedit/models/editors/rdn/rdn_net.py
+++ b/mmedit/models/editors/rdn/rdn_net.py
@@ -15,6 +15,11 @@ class RDNNet(BaseModule):
     'RDN-pytorch/blob/master/models.py'
     Copyright (c) 2021, JaeYun Yeo, under MIT License.
 
+    Most of the implementation follows the implementation in:
+    'https://github.com/sanghyun-son/EDSR-PyTorch.git'
+    'EDSR-PyTorch/blob/master/src/model/rdn.py'
+    Copyright (c) 2017, sanghyun-son, under MIT license.
+
     Args:
         in_channels (int): Channel number of inputs.
         out_channels (int): Channel number of outputs.
@@ -51,16 +56,15 @@ def __init__(self,
             mid_channels, mid_channels,
             kernel_size=3,
             padding=3 // 2)
         # residual dense blocks
-        self.rdbs = nn.ModuleList(
-            [RDB(self.mid_channels, self.channel_growth, self.num_layers)])
-        for _ in range(self.num_blocks - 1):
+        self.rdbs = nn.ModuleList()
+        for _ in range(self.num_blocks):
             self.rdbs.append(
-                RDB(self.channel_growth, self.channel_growth, self.num_layers))
+                RDB(self.mid_channels, self.channel_growth, self.num_layers))
         # global feature fusion
         self.gff = nn.Sequential(
             nn.Conv2d(
-                self.channel_growth * self.num_blocks,
+                self.mid_channels * self.num_blocks,
                 self.mid_channels,
                 kernel_size=1),
             nn.Conv2d(
@@ -165,7 +169,7 @@ def __init__(self, in_channels, channel_growth, num_layers):
         # local feature fusion
         self.lff = nn.Conv2d(
             in_channels + channel_growth * num_layers,
-            channel_growth,
+            in_channels,
             kernel_size=1)
 
     def forward(self, x):
diff --git a/tests/test_models/test_editors/test_rdn/test_rdn_net.py b/tests/test_models/test_editors/test_rdn/test_rdn_net.py
index ba84ed32ca..9e3b939be8 100644
--- a/tests/test_models/test_editors/test_rdn/test_rdn_net.py
+++ b/tests/test_models/test_editors/test_rdn/test_rdn_net.py
@@ -19,6 +19,7 @@ def test_rdn():
         in_channels=3,
         out_channels=3,
         mid_channels=64,
+        channel_growth=32,
         num_blocks=16,
         upscale_factor=scale)
 