
Support both Gluon 1 and 2 in the hybrid containers #19470

Merged
merged 1 commit on Nov 4, 2020
7 changes: 4 additions & 3 deletions python/mxnet/gluon/block.py
@@ -901,6 +901,7 @@ def forward(self, x):
"""
def __init__(self):
super(HybridBlock, self).__init__()
self._v2 = inspect.unwrap(self.hybrid_forward.__func__) is HybridBlock.hybrid_forward
self._cached_graph = ()
self._cached_op = None
self._out_format = None
@@ -985,7 +986,7 @@ def _get_graph_v2(self, *args):

def _get_graph(self, *args):
if not self._cached_graph:
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
return self._get_graph_v1(*args)
else: # Gluon 2 based on deferred compute mode
return self._get_graph_v2(*args)
@@ -1282,7 +1283,7 @@ def _infer_attrs(self, infer_fn, attr, *args):

def infer_shape(self, *args):
"""Infers shape of Parameters from inputs."""
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
# Gluon 1 based on F: hybrid_forward is defined by user
self._infer_attrs('infer_shape', 'shape', *args)
else:
@@ -1406,7 +1407,7 @@ def c_callback(name, op_name, array):
cld()._monitor_all = monitor_all

def __call__(self, x, *args):
if inspect.unwrap(self.hybrid_forward.__func__) is not HybridBlock.hybrid_forward:
if not self._v2:
# Gluon 1 based on F: hybrid_forward is defined by user
return super().__call__(x, *args)
else: # Gluon 2 based on deferred compute mode
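The `self._v2` flag cached in `__init__` replaces the repeated `inspect.unwrap` checks at every call site. The rule: a block is treated as Gluon 2 exactly when it leaves `HybridBlock.hybrid_forward` unoverridden and implements `forward` instead. A minimal sketch of that detection (the two example classes are hypothetical, not part of this PR):

```python
import inspect
from mxnet.gluon import HybridBlock

class LegacyBlock(HybridBlock):
    # Gluon 1 interface: the user overrides hybrid_forward(self, F, x).
    def hybrid_forward(self, F, x):
        return x * 2

class DeferredBlock(HybridBlock):
    # Gluon 2 interface: the user overrides forward(self, x) only.
    def forward(self, x):
        return x * 2

def is_v2(block):
    # Mirrors the check cached as self._v2: hybrid_forward was not
    # overridden, so the block uses forward() and deferred compute.
    return inspect.unwrap(block.hybrid_forward.__func__) is HybridBlock.hybrid_forward

assert not is_v2(LegacyBlock())   # defines hybrid_forward -> Gluon 1 path
assert is_v2(DeferredBlock())     # defines only forward   -> Gluon 2 path
```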
43 changes: 40 additions & 3 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -23,12 +23,13 @@
'Flatten', 'Lambda', 'HybridLambda', 'Concatenate', 'HybridConcatenate', 'Identity']
import warnings
import uuid
import inspect
import numpy as np

from .activations import Activation
from ..block import Block, HybridBlock
from ..utils import _indent
from ... import ndarray as nd, symbol as sym, context
from ... import ndarray as nd, np as mxnp, symbol as sym, context, _deferred_compute as dc
from ...util import is_np_array
from ..parameter import Parameter

@@ -111,15 +112,41 @@ class HybridSequential(HybridBlock):
net.hybridize()
"""
def __init__(self):
super(HybridSequential, self).__init__()
super().__init__()
self._layers = []
self._v2_checked = False

def add(self, *blocks):
"""Adds block on top of the stack."""
for block in blocks:
self._layers.append(block)
self.register_child(block)

def __call__(self, *args, **kwargs):
if self._active and not self._v2_checked and not dc.is_deferred_compute():
# If any of the child Blocks implements the Gluon 2 interface, the
# container must not pass a Symbol to them
if any(inspect.unwrap(chld().hybrid_forward.__func__) is
HybridBlock.hybrid_forward for chld in self._children.values()):
self._v2 = True
self._v2_checked = True
self.forward = self._forward

return super().__call__(*args, **kwargs)


def _forward(self, x, *args):
for block in self._children.values():
x = block()(x, *args)
args = []
if isinstance(x, (tuple, list)):
args = x[1:]
x = x[0]
if args:
x = tuple([x] + list(args))
return x


def hybrid_forward(self, F, x, *args):
for block in self._children.values():
x = block()(x, *args)
@@ -997,9 +1024,19 @@ class HybridConcatenate(HybridSequential):
The axis on which to concatenate the outputs.
"""
def __init__(self, axis=-1):
super(HybridConcatenate, self).__init__()
super().__init__()
self.axis = axis

def _forward(self, x):
out = []
for block in self._children.values():
out.append(block()(x))
if is_np_array():
out = mxnp.concatenate(out, axis=self.axis)
else:
out = nd.concat(*out, dim=self.axis)
return out

def hybrid_forward(self, F, x):
out = []
for block in self._children.values():
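With the `__call__` override above, the first hybridized call inspects the children once (outside deferred compute), and a single Gluon 2 child is enough to flip the container to the eager `_forward` path so that no Symbol ever reaches that child; Gluon 1 children are simply invoked imperatively and traced. A hypothetical usage sketch of such a mixed container (`V2Dense` is illustrative, not from this PR):

```python
import mxnet as mx
from mxnet.gluon import nn, HybridBlock

class V2Dense(HybridBlock):
    # Gluon 2 style child: implements forward(), not hybrid_forward().
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(4, in_units=4)

    def forward(self, x):
        return self.dense(x)

net = nn.HybridSequential()
net.add(nn.Dense(4, in_units=4))  # built-in Gluon 1 layer (hybrid_forward based)
net.add(V2Dense())                # Gluon 2 style block
net.initialize()
net.hybridize()
# On the first call the container detects the Gluon 2 child, sets self._v2,
# and rebinds self.forward to _forward before tracing under deferred compute.
out = net(mx.nd.zeros((2, 4)))
assert out.shape == (2, 4)
```

`HybridConcatenate._forward` follows the same pattern, switching between `mxnet.np.concatenate` and `nd.concat` based on `is_np_array()`.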
70 changes: 60 additions & 10 deletions tests/python/unittest/test_gluon.py
@@ -1011,16 +1011,47 @@ def check_sequential(net):
net.add(dense2)
dense3 = gluon.nn.Dense(10)
net.add(dense3)
net.initialize()

net(mx.nd.zeros((10, 10)))
net.hybridize()
assert net[1] is dense2
assert net[-1] is dense3
slc = net[1:3]
assert len(slc) == 2 and slc[0] is dense2 and slc[1] is dense3
assert isinstance(slc, type(net))

def check_sequential_dc(net):
class MyBlock(mx.gluon.HybridBlock):
def __init__(self):
super().__init__()
self.dense = mx.gluon.nn.Dense(units=10, in_units=10)
self.weight = mx.gluon.Parameter('weight', shape=(10, ))

def forward(self, x):
return self.dense(x) + self.weight.data()

dense1 = MyBlock()
net.add(dense1)
dense2 = MyBlock()
net.add(dense2)
dense3 = MyBlock()
net.add(dense3)

net.initialize()
net.hybridize()
net(mx.nd.zeros((10, 10)))
assert net[1] is dense2
assert net[-1] is dense3
slc = net[1:3]
assert len(slc) == 2 and slc[0] is dense2 and slc[1] is dense3
assert isinstance(slc, type(net))

@pytest.mark.garbage_expected
def test_sequential():
check_sequential(gluon.nn.Sequential())
check_sequential(gluon.nn.HybridSequential())
check_sequential_dc(gluon.nn.HybridSequential())

def test_sequential_warning():
with warnings.catch_warnings(record=True) as w:
@@ -3075,24 +3106,43 @@ def test_ModulatedDeformableConvolution():
with mx.autograd.record():
y = net(x)

def test_concatenate():

@pytest.mark.parametrize('dc', [True, False])
@pytest.mark.parametrize('hybridize', [True, False])
@pytest.mark.garbage_expected
def test_concatenate(dc, hybridize):
if dc:
class MyBlock(mx.gluon.HybridBlock):
def __init__(self, units, activation=None, in_units=0):
super().__init__()
self.dense = mx.gluon.nn.Dense(units, activation=activation, in_units=in_units)

def forward(self, x):
return self.dense(x)
else:
MyBlock = nn.Dense

model = nn.HybridConcatenate(axis=1)
model.add(nn.Dense(128, activation='tanh', in_units=10))
model.add(nn.Dense(64, activation='tanh', in_units=10))
model.add(nn.Dense(32, in_units=10))
model.add(MyBlock(128, activation='tanh', in_units=10))
model.add(MyBlock(64, activation='tanh', in_units=10))
model.add(MyBlock(32, in_units=10))
model2 = nn.Concatenate(axis=1)
model2.add(nn.Dense(128, activation='tanh', in_units=10))
model2.add(nn.Dense(64, activation='tanh', in_units=10))
model2.add(nn.Dense(32, in_units=10))
model2.add(MyBlock(128, activation='tanh', in_units=10))
model2.add(MyBlock(64, activation='tanh', in_units=10))
model2.add(MyBlock(32, in_units=10))

# symbol
x = mx.sym.var('data')
y = model(x)
assert len(y.list_arguments()) == 7
if not dc:
x = mx.sym.var('data')
y = model(x)
assert len(y.list_arguments()) == 7

# ndarray
model.initialize(mx.init.Xavier(magnitude=2.24))
model2.initialize(mx.init.Xavier(magnitude=2.24))
if hybridize:
model.hybridize()
model2.hybridize()
x = model(mx.nd.zeros((32, 10)))
x2 = model2(mx.nd.zeros((32, 10)))
assert x.shape == (32, 224)
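Beyond the single-output children covered by the tests above, the new `HybridSequential._forward` also chains multi-output children: the head of a returned tuple feeds the next block as the main input, and the tail is passed along as extra positional arguments. A hypothetical sketch of that chaining (both classes are illustrative, not part of the test suite):

```python
import mxnet as mx
from mxnet.gluon import nn, HybridBlock

class SplitHead(HybridBlock):
    # Returns a tuple: the main output plus an auxiliary tensor.
    def forward(self, x):
        return x, x * 2

class UseBoth(HybridBlock):
    # Receives the tuple's tail as an extra positional argument.
    def forward(self, x, aux):
        return x + aux

net = nn.HybridSequential()
net.add(SplitHead())
net.add(UseBoth())
net.hybridize()
out = net(mx.nd.ones((2, 3)))  # SplitHead's tuple is unpacked for UseBoth
assert out.shape == (2, 3)
```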