Skip to content

Commit

Permalink
Add with_pool args for vgg (#28684)
Browse files Browse the repository at this point in the history
* add arg for vgg
  • Loading branch information
LielinJiang authored Nov 18, 2020
1 parent 532e4bb commit 01a14e1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python/paddle/vision/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def forward(self, x):
x = self.layer3(x)
x = self.layer4(x)

if self.with_pool > 0:
if self.with_pool:
x = self.avgpool(x)

if self.num_classes > 0:
Expand Down
42 changes: 27 additions & 15 deletions python/paddle/vision/models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class VGG(nn.Layer):
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
features (nn.Layer): vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
features (nn.Layer): Vgg features create by function make_layers.
num_classes (int): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): Use pool before the last three fc layer or not. Default: True.
Examples:
.. code-block:: python
Expand All @@ -54,24 +55,35 @@ class VGG(nn.Layer):
"""

def __init__(self, features, num_classes=1000):
def __init__(self, features, num_classes=1000, with_pool=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2D((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, num_classes), )
self.num_classes = num_classes
self.with_pool = with_pool

if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((7, 7))

if num_classes > 0:
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, num_classes), )

def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = paddle.flatten(x, 1)
x = self.classifier(x)

if self.with_pool:
x = self.avgpool(x)

if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.classifier(x)

return x


Expand Down

0 comments on commit 01a14e1

Please sign in to comment.