Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support both Gluon 1 and 2 in the hybrid containers (#19470)
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu authored Nov 4, 2020
1 parent b33fbd1 commit 3d1df4e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 16 deletions.
7 changes: 4 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ def forward(self, x):
"""
def __init__(self):
super(HybridBlock, self).__init__()
self._v2 = inspect.unwrap(self.hybrid_forward.__func__) is HybridBlock.hybrid_forward
self._cached_graph = ()
self._cached_op = None
self._out_format = None
Expand Down Expand Up @@ -985,7 +986,7 @@ def _get_graph_v2(self, *args):

def _get_graph(self, *args):
if not self._cached_graph:
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
return self._get_graph_v1(*args)
else: # Gluon 2 based on deferred compute mode
return self._get_graph_v2(*args)
Expand Down Expand Up @@ -1282,7 +1283,7 @@ def _infer_attrs(self, infer_fn, attr, *args):

def infer_shape(self, *args):
"""Infers shape of Parameters from inputs."""
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
# Gluon 1 based on F: hybrid_forward is defined by user
self._infer_attrs('infer_shape', 'shape', *args)
else:
Expand Down Expand Up @@ -1406,7 +1407,7 @@ def c_callback(name, op_name, array):
cld()._monitor_all = monitor_all

def __call__(self, x, *args):
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
# Gluon 1 based on F: hybrid_forward is defined by user
return super().__call__(x, *args)
else: # Gluon 2 based on deferred compute mode
Expand Down
43 changes: 40 additions & 3 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
'Flatten', 'Lambda', 'HybridLambda', 'Concatenate', 'HybridConcatenate', 'Identity']
import warnings
import uuid
import inspect
import numpy as np

from .activations import Activation
from ..block import Block, HybridBlock
from ..utils import _indent
from ... import ndarray as nd, symbol as sym, context
from ... import ndarray as nd, np as mxnp, symbol as sym, context, _deferred_compute as dc
from ...util import is_np_array
from ..parameter import Parameter

Expand Down Expand Up @@ -111,15 +112,41 @@ class HybridSequential(HybridBlock):
net.hybridize()
"""
def __init__(self):
    """Initialize an empty hybrid container.

    Supports children implementing either the Gluon 1 (``hybrid_forward``)
    or the Gluon 2 (``forward`` / deferred compute) interface.
    """
    # Bug fix: the scraped diff kept both the old
    # ``super(HybridSequential, self).__init__()`` and the new
    # ``super().__init__()`` line, running base initialization twice.
    # Keep the modern zero-argument form only.
    super().__init__()
    self._layers = []
    # Set lazily on the first hybridized call: whether we have already
    # probed the children for the Gluon 2 interface (see __call__).
    self._v2_checked = False

def add(self, *blocks):
    """Add one or more blocks on top of the stack, preserving call order."""
    for blk in blocks:
        self._layers.append(blk)
        self.register_child(blk)

def __call__(self, *args, **kwargs):
    """Run a forward pass, dispatching between Gluon 1 and Gluon 2 semantics.

    On a hybridized call made outside deferred-compute tracing, probe the
    children once: if any child uses the Gluon 2 interface, switch this
    container to the Gluon 2 code path before delegating to the base call.
    """
    # Only probe when hybridized (_active), not yet decided (_v2_checked),
    # and not while already tracing under deferred compute.
    if self._active and not self._v2_checked and not dc.is_deferred_compute():
        # If any of the child Blocks implements the Gluon 2 interface, the
        # container must not pass a Symbol to them.
        # A child whose ``hybrid_forward`` is still the HybridBlock base
        # implementation has NOT overridden it, i.e. it defines ``forward``
        # instead -- the Gluon 2 convention.  ``inspect.unwrap`` strips any
        # decorator wrappers before the identity check.  ``chld()`` resolves
        # the stored child reference (presumably a weakref -- confirm against
        # Block.register_child).
        if any(inspect.unwrap(chld().hybrid_forward.__func__) is
               HybridBlock.hybrid_forward for chld in self._children.values()):
            self._v2 = True
            self._v2_checked = True
            # Route all future calls through the Gluon 2 implementation.
            self.forward = self._forward
    # NOTE(review): _v2_checked is set only when a Gluon 2 child is found, so
    # an all-Gluon-1 container re-runs this probe on every call -- apparently
    # deliberate, so children added after the first call are still detected.
    return super().__call__(*args, **kwargs)


def _forward(self, x, *args):
for block in self._children.values():
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
x = x[0]
if args:
x = tuple([x] + list(args))
return x


def hybrid_forward(self, F, x, *args):
for block in self._children.values():
x = block()(x, *args)
Expand Down Expand Up @@ -997,9 +1024,19 @@ class HybridConcatenate(HybridSequential):
The axis on which to concatenate the outputs.
"""
def __init__(self, axis=-1):
    """Initialize the concatenating hybrid container.

    Parameters
    ----------
    axis : int, default -1
        Axis along which the child outputs are concatenated.
    """
    # Bug fix: the scraped diff kept both the old
    # ``super(HybridConcatenate, self).__init__()`` and the new
    # ``super().__init__()`` line, running base initialization twice.
    # Keep the modern zero-argument form only.
    super().__init__()
    self.axis = axis

def _forward(self, x):
    """Gluon 2 forward: apply every child to ``x`` and concatenate the
    results along ``self.axis``."""
    outputs = [child()(x) for child in self._children.values()]
    if is_np_array():
        return mxnp.concatenate(outputs, axis=self.axis)
    return nd.concat(*outputs, dim=self.axis)

def hybrid_forward(self, F, x):
out = []
for block in self._children.values():
Expand Down
70 changes: 60 additions & 10 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,16 +1011,47 @@ def check_sequential(net):
net.add(dense2)
dense3 = gluon.nn.Dense(10)
net.add(dense3)
net.initialize()

net(mx.nd.zeros((10, 10)))
net.hybridize()
assert net[1] is dense2
assert net[-1] is dense3
slc = net[1:3]
assert len(slc) == 2 and slc[0] is dense2 and slc[1] is dense3
assert isinstance(slc, type(net))

def check_sequential_dc(net):
    """Populate ``net`` with Gluon 2 (forward-based) children and verify
    indexing/slicing still works after hybridization + deferred compute."""
    class TwoParamBlock(mx.gluon.HybridBlock):
        # A minimal Gluon 2 block: defines forward() instead of
        # hybrid_forward(), mixing a child Dense with a raw Parameter.
        def __init__(self):
            super().__init__()
            self.dense = mx.gluon.nn.Dense(units=10, in_units=10)
            self.weight = mx.gluon.Parameter('weight', shape=(10, ))

        def forward(self, x):
            return self.dense(x) + self.weight.data()

    blocks = [TwoParamBlock() for _ in range(3)]
    for blk in blocks:
        net.add(blk)

    net.initialize()
    net.hybridize()
    net(mx.nd.zeros((10, 10)))

    # Indexing and slicing must return the original child objects.
    assert net[1] is blocks[1]
    assert net[-1] is blocks[2]
    slc = net[1:3]
    assert len(slc) == 2 and slc[0] is blocks[1] and slc[1] is blocks[2]
    assert isinstance(slc, type(net))

@pytest.mark.garbage_expected
def test_sequential():
    # Gluon 1 path for both container flavours, then the Gluon 2
    # (deferred-compute) path for the hybrid container.
    cases = [
        (gluon.nn.Sequential, check_sequential),
        (gluon.nn.HybridSequential, check_sequential),
        (gluon.nn.HybridSequential, check_sequential_dc),
    ]
    for container_cls, checker in cases:
        checker(container_cls())

def test_sequential_warning():
with warnings.catch_warnings(record=True) as w:
Expand Down Expand Up @@ -3075,24 +3106,43 @@ def test_ModulatedDeformableConvolution():
with mx.autograd.record():
y = net(x)

def test_concatenate():

@pytest.mark.parametrize('dc', [True, False])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.garbage_expected
def test_concatenate(dc, hybridize):
if dc:
class MyBlock(mx.gluon.HybridBlock):
def __init__(self, units, activation=None, in_units=0):
super().__init__()
self.dense = mx.gluon.nn.Dense(units, activation=activation, in_units=in_units)

def forward(self, x):
return self.dense(x)
else:
MyBlock = nn.Dense

model = nn.HybridConcatenate(axis=1)
model.add(nn.Dense(128, activation='tanh', in_units=10))
model.add(nn.Dense(64, activation='tanh', in_units=10))
model.add(nn.Dense(32, in_units=10))
model.add(MyBlock(128, activation='tanh', in_units=10))
model.add(MyBlock(64, activation='tanh', in_units=10))
model.add(MyBlock(32, in_units=10))
model2 = nn.Concatenate(axis=1)
model2.add(nn.Dense(128, activation='tanh', in_units=10))
model2.add(nn.Dense(64, activation='tanh', in_units=10))
model2.add(nn.Dense(32, in_units=10))
model2.add(MyBlock(128, activation='tanh', in_units=10))
model2.add(MyBlock(64, activation='tanh', in_units=10))
model2.add(MyBlock(32, in_units=10))

# symbol
x = mx.sym.var('data')
y = model(x)
assert len(y.list_arguments()) == 7
if not dc:
x = mx.sym.var('data')
y = model(x)
assert len(y.list_arguments()) == 7

# ndarray
model.initialize(mx.init.Xavier(magnitude=2.24))
model2.initialize(mx.init.Xavier(magnitude=2.24))
if hybridize:
model.hybridize()
model2.hybridize()
x = model(mx.nd.zeros((32, 10)))
x2 = model2(mx.nd.zeros((32, 10)))
assert x.shape == (32, 224)
Expand Down

0 comments on commit 3d1df4e

Please sign in to comment.