Fix/sequence pool #5229
Changes from 2 commits
```diff
@@ -5,7 +5,7 @@
 __all__ = [
     'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
-    'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool'
+    'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'sums'
 ]
```
```diff
@@ -150,7 +150,7 @@ def func(**kwargs):
         outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
         helper.append_op(
             type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
-        return helper.append_activation(out)
+        return out

     func.__name__ = op_type
     globals()[op_type] = func
```
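Aside (not part of the diff): `_create_op_func_` manufactures one Python function per operator type and registers it in the module namespace; with the change above, the generated function returns the op's output directly instead of appending an activation. A toy, standalone sketch of that registration pattern, with made-up names and no PaddlePaddle dependencies:

```python
# Toy illustration of the globals()-registration pattern used by
# _create_op_func_: each call manufactures a function and installs it
# under the op's name in the module namespace.
def _make_op_func(op_type):
    def func(**kwargs):
        # A real implementation would append an operator to the program here;
        # this toy version only echoes what it would do.
        return ('append_op', op_type, kwargs)
    func.__name__ = op_type
    globals()[op_type] = func

for name in ('mean', 'mul', 'elementwise_add'):
    _make_op_func(name)

print(mean(X='some_variable'))  # ('append_op', 'mean', {'X': 'some_variable'})
```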
```diff
@@ -160,21 +160,9 @@ def func(**kwargs):
 _create_op_func_('mean')
 _create_op_func_('mul')
 _create_op_func_('elementwise_add')
 _create_op_func_('dropout')
 _create_op_func_('reshape')
-
-
-def cast(x, data_type, program=None):
-    helper = LayerHelper('cast', **locals())
-    out = helper.create_tmp_variable(dtype=data_type)
-    helper.append_op(
-        type='cast',
-        inputs={'X': [x]},
-        outputs={'Out': [out]},
-        attrs={'in_data_type': x.data_type,
-               'out_data_type': out.data_type})
-    return out
+# _create_op_func_('cos_sim')


 def cast(x, data_type, program=None):
```
```diff
@@ -202,6 +190,30 @@ def concat(input, axis, program=None, init_program=None):
     return out


+def sums(input, program=None, init_program=None):
+    helper = LayerHelper('sum', **locals())
+    if not isinstance(input, list) and not isinstance(input, tuple):
+        input = [input]
+    out = helper.create_tmp_variable(dtype=input[0].data_type)
```
Review comment:

```python
def input_dtype(self, input_param_name='input'):
    inputs = self.multiple_input(input_param_name)
    dtype = None
    for each in inputs:
        if dtype is None:
            dtype = each.data_type
        elif dtype != each.data_type:
            raise ValueError("Data Type mismatch")
    return dtype
```

Reply: done.
```diff
+    helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
+    return out
```
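A hedged sketch (not part of this PR) of how `sums` could follow the review suggestion and derive its dtype from all inputs via `helper.input_dtype()`. It assumes the `LayerHelper.input_dtype` behaviour shown in the comment above, and is meant as a drop-in variant inside the same module, where `LayerHelper` is already imported:

```python
# Hypothetical variant of sums(); assumes LayerHelper.input_dtype exists as
# shown in the review comment and validates that all inputs share one dtype.
def sums(input, program=None, init_program=None):
    if not isinstance(input, (list, tuple)):
        input = [input]
    helper = LayerHelper('sum', **locals())
    # input_dtype() raises ValueError if the inputs' data types differ.
    dtype = helper.input_dtype(input_param_name='input')
    out = helper.create_tmp_variable(dtype=dtype)
    helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
    return out
```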
```diff
+def cos_sim(X, Y, program=None, init_program=None):
+    helper = LayerHelper('cos_sim', **locals())
+    out = helper.create_tmp_variable(dtype=X.data_type)
```
Review comment: Use …

Reply: Done.
```diff
+    xnorm = helper.create_tmp_variable(dtype=X.data_type)
+    ynorm = helper.create_tmp_variable(dtype=X.data_type)
+    helper.append_op(
+        type='cos_sim',
+        inputs={'X': [X],
+                'Y': [Y]},
+        outputs={'Out': [out],
+                 'XNorm': [xnorm],
+                 'YNorm': [ynorm]})
+    return out, xnorm, ynorm


 def cross_entropy(input, label, **kwargs):
     helper = LayerHelper('cross_entropy', **kwargs)
     out = helper.create_tmp_variable(dtype=input.data_type)
```
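For orientation, an illustrative call to the new `cos_sim` layer; the variables `a` and `b` are assumptions, not from this PR:

```python
# Illustrative only: `a` and `b` are assumed to be existing Variables of
# matching shape created earlier in the same program (e.g. by fc layers).
sim, a_norm, b_norm = cos_sim(X=a, Y=b, program=program, init_program=init_program)
# `sim` holds the cosine similarity; the two norm outputs can usually be ignored.
```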
```diff
@@ -232,11 +244,27 @@ def square_error_cost(input, label, **kwargs):
     return square_out


+def square_error_cost(input, label, **kwargs):
+    helper = LayerHelper('square_error_cost', **kwargs)
+    minus_out = helper.create_tmp_variable(dtype=input.data_type)
+    helper.append_op(
+        type='elementwise_sub',
+        inputs={'X': [input],
+                'Y': [label]},
+        outputs={'Out': [minus_out]})
+
+    square_out = helper.create_tmp_variable(dtype=input.data_type)
+    helper.append_op(
+        type='pow',
+        inputs={'X': [minus_out]},
+        outputs={'Y': [square_out]},
+        attrs={'factor': 2.0})
+    return square_out


 def sequence_conv(input,
                   num_filters,
                   name=None,
                   filter_size=3,
                   act=None,
                   stride=1,
                   padding=None,
                   bias_attr=None,
```
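The added `square_error_cost` composes two ops, `elementwise_sub` followed by `pow` with `factor=2.0`, i.e. it computes the elementwise squared difference. A small NumPy reference check of that intended math (illustrative, not framework code):

```python
import numpy as np

# Reference for what the composed ops should produce:
# pow(elementwise_sub(input, label), factor=2.0) == (input - label) ** 2
input_np = np.array([1.0, 2.0, 3.0])
label_np = np.array([1.5, 2.0, 2.0])
print((input_np - label_np) ** 2.0)  # [0.25 0.   1.  ]
```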
```diff
@@ -250,7 +278,7 @@ def sequence_conv(input,
     helper = LayerHelper('sequence_conv', **locals())
     dtype = helper.input_dtype()

-    filter_shape = [num_filters, filter_size]
+    filter_shape = [num_filters * filter_size, filter_size]
     filter = helper.create_parameter(
         attr=helper.param_attr, shape=filter_shape, dtype=dtype)
     pre_bias = helper.create_tmp_variable(dtype)
```
```diff
@@ -259,17 +287,15 @@
         type='sequence_conv',
         inputs={
             'X': [input],
-            'Filter': filter,
+            'Filter': [filter],
         },
         outputs={"Out": pre_bias},
         attrs={
             'context_stride': stride,
             'context_start': 0,
             'context_length': filter_size
         })
-
-    pre_act = helper.append_bias_op(pre_bias)
-    return helper.append_activation(pre_act)
+    return pre_bias
```
Review comment: Why remove activation?

Reply: I'm not sure sequence_conv needs a non-linear activation, so I removed it.

Review comment: I think we shall retain it. Users can set …
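To make the trade-off in the thread concrete, a standalone toy sketch (not Paddle code) of the pattern the reviewer asks to keep: an activation argument that defaults to doing nothing, so a purely linear `sequence_conv` stays one call away:

```python
# Standalone illustration of "optional activation, identity by default".
def apply_activation(x, act=None):
    if act is None:
        # No activation requested: pass the pre-activation value through.
        return x
    activations = {'relu': lambda v: max(v, 0.0)}
    return activations[act](x)

print(apply_activation(-3.0))          # -3.0 (no activation appended)
print(apply_activation(-3.0, 'relu'))  # 0.0
```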
```diff
@@ -311,8 +337,8 @@ def conv2d(input,
     helper.append_op(
         type='conv2d',
         inputs={
-            'Input': input,
-            'Filter': filter,
+            'Input': [input],
+            'Filter': [filter],
         },
         outputs={"Output": pre_bias},
         attrs={'strides': stride,
```
```diff
@@ -324,31 +350,32 @@
     return helper.append_activation(pre_act)


-def sequence_pool(input,
-                  pool_size,
-                  pool_type,
-                  pool_stride=1,
-                  pool_padding=0,
-                  global_pooling=False,
-                  program=None,
-                  init_program=None):
-    ENUM_POOL_TYPE = set(["max", "avg", "sqrt", "last", "first"])
-    if pool_type not in ENUM_POOL_TYPE:
+def sequence_pool(input, pool_type, program=None, init_program=None):
+    # FIXME(dzh) : want to unify the argument of python layer
+    # function. So we ignore some unecessary attributes
+
+    ENUM_POOL_TYPE = dict({
+        "AVERAGE": 0,
+        "SUM": 1,
+        "SQRT": 2,
+        "MAX": 3,
+        "LAST": 4,
+        "FIRST": 5
+    })
+    if pool_type.upper() not in ENUM_POOL_TYPE:
         raise ValueError("Unknown pool_type: '%s'. It can only be %s.",
-                         str(pool_type), " ".join(ENUM_POOL_TYPE))
+                         str(pool_type), " ".join(ENUM_POOL_TYPE.keys()))

     helper = LayerHelper('sequence_pool', **locals())
     dtype = helper.input_dtype()
     pool_out = helper.create_tmp_variable(dtype)

+    # FIXME(dzh): strategy
     helper.append_op(
         type="sequence_pool",
         inputs={"X": [input]},
-        outputs={"Out": pool_out},
-        attrs={"strategy": pool_type})
+        outputs={"Out": [pool_out]},
+        attrs={"strategy": ENUM_POOL_TYPE[pool_type.upper()]})

     return pool_out
```
Review comment: What's this change for?

Reply: "max", "sqrt" are keywords of …

Review comment: I am wondering why …

Reply: Because the "strategy" attribute needs a fixed enum index, not a string:

```cpp
AddAttr<int>(
    "strategy",
    "(int, default AVERAGE) the pooling strategy of SequencePoolOp.")
    .SetDefault(AVERAGE)
    .InEnum({AVERAGE, SUM, SQRT, MAX, LAST, FIRST});
```

This attribute will definitely be replaced in another PR, along with the sequence_conv attributes, etc. Those tons of code will be rewritten after I finish book chapter 5.
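To illustrate the new interface: `pool_type` is now a case-insensitive string that is mapped to the integer enum the op expects. Illustrative calls, with `conv_out` assumed to be an existing sequence (LoD) Variable produced by an earlier layer:

```python
# Illustrative only; `conv_out` comes from, e.g., a sequence_conv layer.
pooled_max = sequence_pool(input=conv_out, pool_type='max')  # strategy -> 3
pooled_sum = sequence_pool(input=conv_out, pool_type='SUM')  # strategy -> 1
# Any other string raises ValueError listing the supported names.
```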
```diff
@@ -378,8 +405,8 @@ def pool2d(input,
     helper.append_op(
         type="pool2d",
-        inputs={"X": input},
-        outputs={"Out": pool_out},
+        inputs={"X": [input]},
+        outputs={"Out": [pool_out]},
         attrs={
             "poolingType": pool_type,
             "ksize": pool_size,
```
```diff
@@ -391,96 +418,6 @@ def pool2d(input,
     return pool_out


-def batch_norm(input,
-               act=None,
-               is_test=False,
-               momentum=0.9,
-               epsilon=1e05,
-               param_attr=None,
-               bias_attr=None,
-               data_layout='NCHW',
-               program=None,
-               init_program=None):
-    helper = LayerHelper('batch_norm', **locals())
-    dtype = helper.input_dtype()
-
-    input_shape = input.shape
-    if data_layout == 'NCHW':
-        channel_num = input_shape[1]
-    else:
-        if data_layout == 'NHWC':
-            channel_num = input_shape[-1]
-        else:
-            raise ValueError("unsupported data layout:" + data_layout)
-
-    def get_init_attr(value):
-        if not isinstance(value, float):
-            raise ValueError("attr value should be a float")
-        return {'type': 'fill_constant', 'value': value}
-
-    def prepend_init_op(var, init_attr):
-        assert isinstance(var, Variable)
-        op_type = init_attr['type']
-        init_attr['shape'] = var.shape
-        init_attr['data_type'] = int(var.data_type)
-        op = var.block.prepend_op(
-            type=op_type, inputs=None, outputs={'Out': [var]}, attrs=init_attr)
-        return op
-
-    def create_persistable_var(dtype, shape, init_attr=None):
-        name = unique_name(".".join([helper.name, "xxxx"]))
-        var = init_program.global_block().create_var(
-            dtype=dtype, shape=shape, name=name, persistable=True)
-        if 'init_attr' is not None:
-            prepend_init_op(var, init_attr)
-        return program.global_block().create_var(
-            name=name, dtype=dtype, shape=shape, persistable=True)
-
-    param_shape = [channel_num]
-
-    # create parameter
-    scale = helper.create_parameter(
-        attr=helper.param_attr, shape=param_shape, dtype=dtype)
-    bias = helper.create_parameter(
-        attr=helper.param_attr, shape=param_shape, dtype=dtype)
-
-    # create input
-    mean = create_persistable_var(dtype, param_shape, get_init_attr(0.0))
-    variance = create_persistable_var(dtype, param_shape, get_init_attr(1.0))
-
-    # create output
-    # mean and mean_out share the same memory
-    mean_out = mean
-    # variance and variance out share the same memory
-    variance_out = variance
-    saved_mean = helper.create_tmp_variable(dtype)
-    saved_variance = helper.create_tmp_variable(dtype)
-
-    batch_norm_out = helper.create_tmp_variable(dtype)
-
-    helper.append_op(
-        type="batch_norm",
-        inputs={
-            "X": input,
-            "Scale": scale,
-            "Bias": bias,
-            "Mean": mean,
-            "Variance": variance
-        },
-        outputs={
-            "Y": batch_norm_out,
-            "MeanOut": mean_out,
-            "VarianceOut": variance_out,
-            "SavedMean": saved_mean,
-            "SavedVariance": saved_variance
-        },
-        attrs={"momentum": momentum,
-               "epsilon": epsilon,
-               "is_test": is_test})
-
-    return helper.append_activation(batch_norm_out)
-
-
 class BlockGuard(object):
     """
     BlockGuard used to create sub-block in program by using Python `with`
```
Review comment: Why remove batch_norm layer?

Reply: That removal was a mistake caused by the merge conflict. It has been fixed.
```diff
@@ -103,22 +103,18 @@ def sequence_conv_pool(input,
                        filter_size,
-                       pool_size,
-                       pool_stride,
                        act,
                        program=None,
                        init_program=None):
     conv_out = layers.sequence_conv(
         input=input,
         num_filters=num_filters,
         filter_size=filter_size,
         act=act,
         program=program,
         init_program=init_program)

     pool_out = layers.sequence_pool(
         input=conv_out,
-        pool_size=pool_size,
-        pool_type='max',
-        pool_stride=pool_stride,
+        pool_type='sum',
         program=program,
         init_program=init_program)
     return pool_out
```

Review comment (on the removed `pool_stride` argument): …

Reply: Done.

Review comment (on `pool_type='sum'`): I think …

Reply: Fixed.
Review comment: There is no need to do the conversion here. The operator's constructor will do this.

Reply: Done.
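Closing with a hedged usage sketch of the updated `sequence_conv_pool`; the `nets` module alias, the variable `emb`, and the argument values are illustrative assumptions, not taken from this PR:

```python
# Illustrative only: assumes `emb` is a sequence (LoD) Variable produced by an
# embedding layer earlier in the program, and that program/init_program are
# set up as in the framework's existing tests.
conv_pool_out = nets.sequence_conv_pool(
    input=emb,
    num_filters=128,
    filter_size=3,
    act='tanh',
    program=program,
    init_program=init_program)
# After this PR the helper always pools with pool_type='sum' internally
# (a choice questioned in the review thread above).
```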