Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

make gluon rnn layers hybrid blocks #11482

Merged
merged 6 commits into from
Aug 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 62 additions & 70 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ... import ndarray
from .. import Block
from ... import ndarray, symbol
from .. import HybridBlock, tensor_types
from . import rnn_cell


class _RNNLayer(Block):
class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
def __init__(self, hidden_size, num_layers, layout,
dropout, bidirectional, input_size,
Expand All @@ -52,33 +51,28 @@ def __init__(self, hidden_size, num_layers, layout,

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

self.i2h_weight = []
self.h2h_weight = []
self.i2h_bias = []
self.h2h_bias = []

ng, ni, nh = self._gates, input_size, hidden_size
for i in range(num_layers):
for j in (['l', 'r'] if self._dir == 2 else ['l']):
self.i2h_weight.append(
self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni),
init=i2h_weight_initializer,
allow_deferred_init=True))
self.h2h_weight.append(
self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh),
init=h2h_weight_initializer,
allow_deferred_init=True))
self.i2h_bias.append(
self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,),
init=i2h_bias_initializer,
allow_deferred_init=True))
self.h2h_bias.append(
self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,),
init=h2h_bias_initializer,
allow_deferred_init=True))
for j in ['l', 'r'][:self._dir]:
self._register_param('{}{}_i2h_weight'.format(j, i),
shape=(ng*nh, ni),
init=i2h_weight_initializer)
self._register_param('{}{}_h2h_weight'.format(j, i),
shape=(ng*nh, nh),
init=h2h_weight_initializer)
self._register_param('{}{}_i2h_bias'.format(j, i),
shape=(ng*nh,),
init=i2h_bias_initializer)
self._register_param('{}{}_h2h_bias'.format(j, i),
shape=(ng*nh,),
init=h2h_bias_initializer)
ni = nh * self._dir

self._unfused = self._unfuse()
def _register_param(self, name, shape, init):
p = self.params.get(name, shape=shape, init=init,
allow_deferred_init=True)
setattr(self, name, p)
return p

def __repr__(self):
s = '{name}({mapping}, {_layout}'
Expand All @@ -89,12 +83,23 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
shape = self.i2h_weight[0].shape
shape = self.l0_i2h_weight.shape
mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def _collect_params_with_prefix(self, prefix=''):
if prefix:
prefix += '.'
def convert_key(key): # for compatibility with old parameter format
key = key.split('_')
return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:]))
ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()}
for name, child in self._children.items():
ret.update(child._collect_params_with_prefix(prefix + name))
return ret

def state_info(self, batch_size=0):
raise NotImplementedError

Expand All @@ -111,7 +116,7 @@ def _unfuse(self):
'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
**kwargs)}[self._mode]

stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params)
stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params)
with stack.name_scope():
ni = self._input_size
for i in range(self._num_layers):
Expand Down Expand Up @@ -169,63 +174,50 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def forward(self, inputs, states=None):
batch_size = inputs.shape[self._layout.find('N')]
def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
states = self.begin_state(batch_size, ctx=inputs.context)
if isinstance(states, ndarray.NDArray):
if F is ndarray:
states = self.begin_state(batch_size, ctx=inputs.context)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
if self._input_size == 0:
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
out = self._forward_kernel(inputs, states)
if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)

# out is (output, state)
return out[0] if skip_states else out

def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
outputs, states = self._unfused.unroll(
inputs.shape[axis], inputs, states,
layout=self._layout, merge_outputs=True)
new_states = []
for i in range(ns):
state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
new_states.append(state)

return outputs, new_states

def _forward_kernel(self, inputs, states):
def _forward_kernel(self, F, inputs, states, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
params = sum(zip(self.i2h_weight, self.h2h_weight), ())
params += sum(zip(self.i2h_bias, self.h2h_bias), ())
params = (i.data(ctx).reshape((-1,)) for i in params)
params = ndarray.concat(*params, dim=0)

rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
for l in range(self._num_layers)
for d in ['l', 'r'][:self._dir]
for g in ['i2h', 'h2h'])
params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
outputs, states = rnn[0], [rnn[1]]

if self._layout == 'NTC':
outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
outputs = F.swapaxes(outputs, dim1=0, dim2=1)

return outputs, states

Expand Down
127 changes: 102 additions & 25 deletions src/operator/nn/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,65 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
return dshape.Size() != 0;
}

// Concat for RNN param deals with the reverse shape inference from output
// for the special case of concatenating RNN parameters.
// The first (and sometimes the second) input may be unknown on the target axis.
// If the two inputs are unknown, they always have the same shape.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
using namespace mshadow;
const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
TShape dshape;
index_t size = 0;
int num_zero = 0;
int axis = -1;
for (int i = 0; i < param_.num_args; ++i) {
TShape tmp = (*in_shape)[i];
if (tmp.ndim()) {
axis = CheckAxis(param_.dim, tmp.ndim());
num_zero += tmp[axis] == 0;
size += tmp[axis];
tmp[axis] = 0;
shape_assign(&dshape, tmp);
}
}

TShape tmp = (*out_shape)[0];
if (tmp.ndim()) {
axis = CheckAxis(param_.dim, tmp.ndim());
tmp[axis] = 0;
shape_assign(&dshape, tmp);
}

if (dshape.ndim() == 0) return false;

for (int i = 0; i < param_.num_args; ++i) {
CHECK(shape_assign(&(*in_shape)[i], dshape))
<< "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
}

if (!num_zero) dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
if ((*out_shape)[0][axis] != 0 && num_zero) {
int residual = (*out_shape)[0][axis] - size;
CHECK_GE(residual, 0)
<< "Input size already exceeds output size. Residual: " << residual;
CHECK(num_zero <= 2 && num_zero >= 0)
<< "Expecting 1 or 2 inputs that need shape inference. Got: " << num_zero;
bool need_infer = !(*out_shape)[0].Size();
for (int i = 0; i < num_zero; i++) {
(*in_shape)[i*2][axis] = residual / num_zero;
need_infer = need_infer || !(*in_shape)[i].Size();
}
return !need_infer;
}

return dshape.Size() != 0;
}

static bool ConcatType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type,
std::vector<int> *out_type) {
Expand Down Expand Up @@ -228,6 +287,34 @@ struct ConcatGrad {

DMLC_REGISTER_PARAMETER(ConcatParam);

#define CONCAT_FORWARD_ATTRS \
.set_num_inputs([](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
return params.num_args; \
}) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ConcatParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
[](const NodeAttrs& attrs) { \
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
std::vector<std::string> ret; \
for (int i = 0; i < params.num_args; ++i) { \
ret.push_back(std::string("arg") + std::to_string(i)); \
} \
return ret; \
}) \
.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"output"}; \
}) \
.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
.set_attr<std::string>("key_var_num_args", "num_args")


NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
Expand Down Expand Up @@ -268,37 +355,13 @@ Example::
[ 5., 5., 8., 8.]]

)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
std::vector<std::string> ret;
for (int i = 0; i < params.num_args; ++i) {
ret.push_back(std::string("arg") + std::to_string(i));
}
return ret;
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

Expand All @@ -320,5 +383,19 @@ NNVM_REGISTER_OP(_backward_Concat)
#endif
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

// _rnn_param_concat is a custom concat op with specialized infer_shape,
// which handles the case where the first one or two inputs may have
// unknown shape that can be inferred from output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
CONCAT_FORWARD_ATTRS
.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

} // namespace op
} // namespace mxnet
4 changes: 4 additions & 0 deletions src/operator/nn/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_rnn_param_concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);

NNVM_REGISTER_OP(_backward_Concat)
.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);

Expand Down
6 changes: 3 additions & 3 deletions src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
DMLC_REGISTER_PARAMETER(RNNParam);

MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
implemented, with both multi-layer and bidirectional support.
**Vanilla RNN**
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
ReLU and Tanh.
With ReLU activation function:
Expand All @@ -63,7 +63,7 @@ With Tanh activtion function:
.. math::
h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
Reference paper: Finding structure in time - Elman, 1988.
Reference paper: Finding structure in time - Elman, 1988.
https://crl.ucsd.edu/~elman/Papers/fsit.pdf
**LSTM**
Expand Down
Loading