Commit 82345f2

Merge pull request apache#42 from heliqi/paddle_frontend
add masked_select meshgrid scatter scatter_nd_add op
jiangjiajun authored Sep 18, 2021
2 parents c5b5273 + 15b44d8 commit 82345f2
Showing 2 changed files with 166 additions and 21 deletions.
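
For orientation, the following sketch (NumPy stand-ins with made-up values, not code from this diff) illustrates what the four newly supported Paddle ops compute:

```python
# Illustrative semantics of the four ops added by this PR (NumPy stand-ins).
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")

# masked_select: elements where the mask is True, flattened in row-major order.
mask = np.array([[True, False], [False, True]])
assert (x[mask] == np.array([1.0, 4.0], dtype="float32")).all()

# meshgrid: matrix ("ij") indexing, so every output has shape (len(a), len(b)).
a, b = np.meshgrid(np.arange(2), np.arange(3), indexing="ij")
assert a.shape == b.shape == (2, 3)

# scatter (overwrite=True): rows of `updates` replace rows of x at `idx`.
idx = np.array([1, 0])
updates = np.array([[10.0, 10.0], [20.0, 20.0]], dtype="float32")
out = x.copy()
out[idx] = updates
assert (out == np.array([[20.0, 20.0], [10.0, 10.0]])).all()

# scatter_nd_add: accumulate updates at multi-dimensional coordinates.
out = x.copy()
coords = np.array([[0, 0], [1, 1]])  # two (row, col) coordinates
np.add.at(out, tuple(coords.T), [5.0, 5.0])
assert (out == np.array([[6.0, 2.0], [3.0, 9.0]])).all()
```
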
87 changes: 74 additions & 13 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -52,12 +52,11 @@ class ControlFlow:
 
     @classmethod
     def convert_block(cls, graph, block):
-        for i, op in enumerate(block.ops):
+        for op in block.ops:
             if op.type in ControlFlow.operators:
                 raise Exception("Nested Control Flow Not Support Yet.")
-            else:
-                convert_func = _convert_map[op.type]
-                convert_func(graph, op, block)
+            convert_func = _convert_map[op.type]
+            convert_func(graph, op, block)
 
     @classmethod
     def convert(cls, graph, op, program):
@@ -66,6 +65,8 @@ def convert(cls, graph, op, program):
 
     @classmethod
     def convert_while(cls, graph, op, program):
+        """Operator converter for while."""
+
         sub_block_id = op.attr("sub_block").id
         sub_block = program.blocks[sub_block_id]
         input_names = op.input("X")
@@ -77,7 +78,6 @@ def convert_while(cls, graph, op, program):
                 continue
             if name not in input_names:
                 raise Exception("Output '{}' not in inputs".format(name))
-        inputs = [graph.get_node(x) for x in op.input("X")]
 
         sub_graph = GraphProto(graph.freeze_params)
         sub_graph.set_params(graph.get_params())
@@ -96,7 +96,6 @@ def cond_fn(*loop_inputs):
             return _op.equal(squeezed_cond, _expr.const(True, "bool"))
 
         def body_fn(*loop_inputs):
-            cond = loop_inputs[0]
             body_inputs = loop_inputs[1:]
             for i, ipt in enumerate(body_inputs):
                 sub_graph.add_node(input_names[i], ipt)
@@ -141,13 +140,13 @@ def _dtype_shape_promotion(inputs):
         if r == 0:
             inputs[i] = _op.expand_dims(inputs[i], axis=0)
 
-    dtypes = set([dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs])
+    dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs)
     if len(dtypes) == 1:
         return inputs
-    max_dtype = dtype_order(max(dtypes))
-    for i in range(len(inputs)):
-        if infer_type(inputs[i]).checked_type.dtype != max_dtype:
-            inputs[i] = inputs[i].astype(max_dtype)
+    max_dtype = dtype_order[max(dtypes)]
+    for i, input_op in enumerate(inputs):
+        if infer_type(input_op).checked_type.dtype != max_dtype:
+            inputs[i] = input_op.astype(max_dtype)
     return inputs
 
 
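The corrected line above indexes the ranking list (`dtype_order[max(dtypes)]`) instead of calling it, which would raise a TypeError. A minimal standalone sketch of the promotion rule, assuming a low-to-high `dtype_order` ranking (the real list is defined elsewhere in this file; the one below is only an example):

```python
# Promotion rule sketch: rank each dtype, keep the highest-ranked one.
dtype_order = ["bool", "int32", "int64", "float16", "float32", "float64"]  # assumed ranking

def promote(dtypes):
    return dtype_order[max(dtype_order.index(d) for d in dtypes)]

assert promote(["int32", "float32"]) == "float32"
assert promote(["int64", "int32"]) == "int64"
```
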
@@ -1145,6 +1144,20 @@ def convert_logsumexp(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_masked_select(g, op, block):
+    """Operator converter for masked_select."""
+
+    x = g.get_node(op.input("X")[0])
+    mask = g.get_node(op.input("Mask")[0])
+    index = _op.transform.argwhere(mask)
+    shape = infer_shape(index)
+    perm = list(range(0, len(shape) - 1))
+    perm.insert(0, len(shape) - 1)
+    index = _op.transpose(index, axes=perm)
+    out = _op.gather_nd(x, index, 0, shape[-1])
+    g.add_node(op.output("Y")[0], out)
+
+
 def convert_matmul(g, op, block):
     """Operator converter for matmul."""
 
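The lowering idea in `convert_masked_select`: `argwhere` yields one coordinate row per True element of the mask, and Relay's `gather_nd` expects the coordinate axis to come first, hence the transpose. A NumPy sketch of the same data movement (stand-ins, not Relay calls):

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
mask = np.array([[True, False], [False, True]])

index = np.argwhere(mask)  # shape (num_true, x.ndim) == (2, 2)
index = index.T            # shape (x.ndim, num_true): coordinate axis first
out = x[tuple(index)]      # emulates gather_nd(x, index)
assert (out == np.array([1.0, 4.0], dtype="float32")).all()
```
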
@@ -1254,6 +1267,16 @@ def flatten_to_nd(x, x_shape, nd=3):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_meshgrid(g, op, block):
+    """Operator converter for meshgrid."""
+
+    inputs = op.input("X")
+    x = [g.get_node(i) for i in inputs]
+    outs = _op.meshgrid(x, indexing="ij")
+    for i, out in enumerate(outs):
+        g.add_node(op.output("Out")[i], out)
+
+
 def convert_mul(g, op, block):
     """Operator converter for mul."""
 
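Note that the converter hard-codes `indexing="ij"`, matching Paddle's matrix-indexing convention for `meshgrid`. A quick NumPy check of that convention (illustrative only):

```python
import numpy as np

# With "ij" indexing, every output tensor has shape (len(x), len(y)),
# which is what paddle.meshgrid produces.
xs, ys = np.meshgrid(np.arange(2), np.arange(3), indexing="ij")
assert xs.shape == ys.shape == (2, 3)
```
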
@@ -1733,10 +1756,44 @@ def convert_scale(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_scatter(g, op, block):
+    """Operator converter for scatter."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Ids")[0])
+    updates = g.get_node(op.input("Updates")[0])
+    overwrite = op.attr("overwrite")
+
+    shape = infer_shape(updates)
+    ndims = len(shape)
+    index = _op.expand_dims(index, axis=-1, num_newaxis=ndims - 1)
+    index = _op.transform.broadcast_to(index, shape)
+
+    if overwrite:
+        out = _op.scatter(x, index, updates, axis=0)
+    else:
+        out = _op.scatter_add(_op.zeros_like(x), index, updates, axis=0)
+        out += _op.scatter(x, index, _op.zeros_like(updates), axis=0)
+    g.add_node(op.output("Out")[0], out)
+
+
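Two details in `convert_scatter` deserve a note: Relay's `scatter`/`scatter_add` along axis 0 require `index` to be broadcast to the shape of `updates` (the `expand_dims` plus `broadcast_to` pair), and the `overwrite=False` branch composes two scatters so that indexed rows become the sum of their updates while untouched rows keep the values from `x`. A NumPy sketch of that composition (stand-in values):

```python
import numpy as np

x = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
index = np.array([0, 0, 2])
updates = np.array([[10.0, 10.0], [20.0, 20.0], [30.0, 30.0]])

acc = np.zeros_like(x)          # scatter_add(zeros_like(x), index, updates)
np.add.at(acc, index, updates)  # row 0 -> 10 + 20, row 2 -> 30
masked_x = x.copy()             # scatter(x, index, zeros_like(updates))
masked_x[index] = 0.0           # indexed rows of x are zeroed out
out = acc + masked_x
assert (out == np.array([[30.0, 30.0], [2.0, 2.0], [30.0, 30.0]])).all()
```
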
+def convert_scatter_nd_add(g, op, block):
+    """Operator converter for scatter_nd_add."""
+
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    updates = g.get_node(op.input("Updates")[0])
+    indices_dim = len(infer_shape(index))
+    axes = list(range(indices_dim))
+    index = _op.transpose(index, axes[-1:] + axes[:-1])
+    out = _op.scatter_nd(x, index, updates, mode="add")
+    g.add_node(op.output("Out")[0], out)
+
+
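The transpose in `convert_scatter_nd_add` exists because Paddle stores the coordinate tuple in the last axis of `index`, while Relay's `scatter_nd` expects it in the first. A NumPy sketch of the axis move (index values borrowed from the test added below):

```python
import numpy as np

index = np.array([[1, 1], [0, 1], [1, 3]])  # (3, 2): three 2-D coordinates
axes = list(range(index.ndim))              # [0, 1]
index_t = index.transpose(axes[-1:] + axes[:-1])  # (2, 3): coordinates first
assert index_t.shape == (2, 3)
```
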
 def convert_selu(g, op, block):
     """Operator converter for selu."""
 
-    x = g.get_node(op.input("x")[0])
+    x = g.get_node(op.input("X")[0])
     dtype = infer_type(x).checked_type.dtype
     alpha = _op.const(op.attr("alpha"), dtype)
     scale = _op.const(op.attr("scale"), dtype)
@@ -2211,8 +2268,10 @@ def convert_where(g, op, block):
     "logsigmoid": convert_logsigmoid,
     "log_softmax": convert_logsoftmax,
     "logsumexp": convert_logsumexp,
+    "masked_select": convert_masked_select,
     "matmul": convert_matmul,
     "matmul_v2": convert_matmul,
+    "meshgrid": convert_meshgrid,
     "mv": convert_mv,
     "mul": convert_mul,
     "nearest_interp_v2": convert_interpolate,
@@ -2241,6 +2300,8 @@ def convert_where(g, op, block):
     "round": convert_unary_op,
     "rsqrt": convert_unary_op,
     "scale": convert_scale,
+    "scatter": convert_scatter,
+    "scatter_nd_add": convert_scatter_nd_add,
     "selu": convert_selu,
     "shape": convert_shape,
     "sigmoid": convert_unary_op,
@@ -2370,7 +2431,7 @@ def ops_to_relay(self, program, input_specs=None):
             for input_spec in input_specs:
                 convert_feed(self, input_spec, None)
         global_block = program.blocks[0]
-        for i, op in enumerate(global_block.ops):
+        for op in global_block.ops:
             if op.type == "fetch":
                 continue
             if op.type in ControlFlow.operators:
100 changes: 92 additions & 8 deletions tests/python/frontend/paddlepaddle/test_forward.py
@@ -875,9 +875,9 @@ def __init__(self, channels, groups):
     def forward(self, inputs):
         return self.group_norm(inputs)
 
-    input_shape = [2, 6, 2, 2]
+    input_shape = [2, 6, 10, 10]
     x = paddle.rand(input_shape, dtype="float32")
-    verify_model(GroupNorm(6, 6), x)
+    verify_model(GroupNorm(6, 6), x, rtol=1e-4, atol=1e-4)
 
 
 @tvm.testing.uses_gpu
@@ -906,7 +906,6 @@ def forward(self, inputs):
         "hardtanh",
         "log_sigmoid",
         "log_softmax",
-        "relu",
         "relu6",
         "selu",
         "sigmoid",
@@ -919,7 +918,7 @@ def forward(self, inputs):
     ]
     for op_name in op_list:
         verify_model(Activation(op_name), input_data=input_data)
-        verify_model(Activation(op_name), input_data=input_data_2, rtol=1e-9, atol=1e-6)
+        verify_model(Activation(op_name), input_data=input_data_2)
 
 
 @tvm.testing.uses_gpu
@@ -1170,6 +1169,32 @@ def multiply3(inputs, inputs2):
     verify_model(multiply3, input_data=[input_data, input_data2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_masked_select():
+    @paddle.jit.to_static
+    def masked_select(x):
+        mask_data = np.array(
+            [[True, False, False, False], [True, True, False, False], [True, False, False, False]]
+        ).astype("bool")
+        mask = paddle.to_tensor(mask_data)
+        mask = paddle.logical_not(mask)
+        return paddle.masked_select(x, mask)
+
+    @paddle.jit.to_static
+    def masked_select2(x, mask):
+        return paddle.masked_select(x, mask)
+
+    data = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]).astype(
+        "float32"
+    )
+    x = paddle.to_tensor(data)
+    verify_model(masked_select, x)
+    input_shape = [2, 3, 10]
+    x = paddle.rand(input_shape, dtype="float32")
+    mask = paddle.randint(0, 2, input_shape).astype("bool")
+    verify_model(masked_select2, [x, mask], input_shape=[input_shape, input_shape])
+
+
 @tvm.testing.uses_gpu
 def test_forward_matmul():
     class MatMul1(nn.Layer):
@@ -1198,6 +1223,17 @@ def forward(self, input1, input2):
 
 
 @tvm.testing.uses_gpu
+def test_forward_meshgrid():
+    @paddle.jit.to_static
+    def t(x, y, z):
+        return paddle.meshgrid(x, y, z)
+
+    x = paddle.randint(low=0, high=100, shape=[2])
+    y = paddle.randint(low=0, high=100, shape=[3])
+    z = paddle.randint(low=0, high=100, shape=[5])
+    verify_model(t, [x, y, z])
+
+
 def test_forward_mm():
     class Mm(nn.Layer):
         def forward(self, input1, input2):
@@ -1593,6 +1629,50 @@ def scale2(inputs):
     verify_model(scale2, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_scatter():
+    @paddle.jit.to_static
+    def scatter(x, index, updates):
+        return paddle.scatter(x, index, updates, overwrite=True)
+
+    @paddle.jit.to_static
+    def scatter2(x, index, updates):
+        return paddle.scatter(x, index, updates, overwrite=False)
+
+    x = paddle.rand([10, 8, 5], dtype="float32")
+    index = paddle.to_tensor(
+        [
+            2,
+            1,
+            0,
+            6,
+        ]
+    )
+    updates = paddle.rand([4, 8, 5], dtype="float32")
+    verify_model(scatter, [x, index, updates])
+    verify_model(scatter2, [x, index, updates])
+
+
+def test_forward_scatter_nd():
+    @paddle.jit.to_static
+    def scatter_nd(index, updates):
+        shape = [3, 5, 9, 10]
+        return paddle.scatter_nd(index, updates, shape)
+
+    @paddle.jit.to_static
+    def scatter_nd_add(x, index, updates):
+        return paddle.scatter_nd_add(x, index, updates)
+
+    index_data = np.array([[1, 1], [0, 1], [1, 3]]).astype(np.int64)
+    index = paddle.to_tensor(index_data)
+    updates = paddle.rand(shape=[3, 9, 10], dtype="float32")
+    verify_model(scatter_nd, [index, updates])
+    x = paddle.rand(shape=[3, 5, 4, 9, 10], dtype="float32")
+    updates = paddle.rand(shape=[3, 2, 9, 10], dtype="float32")
+    index = paddle.randint(0, 4, shape=[3, 2, 3])
+    verify_model(scatter_nd_add, [x, index, updates])
+
+
 @tvm.testing.uses_gpu
 def test_forward_slice():
     @paddle.jit.to_static
Expand Down Expand Up @@ -1735,7 +1815,7 @@ class Std6(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
return paddle.std(inputs, unbiased=False)

class Std7(nn.Layer):
@paddle.jit.to_static
def forward(self, inputs):
@@ -1759,8 +1839,8 @@ class Subtract(nn.Layer):
         def forward(self, x, y):
             return paddle.subtract(x, y)
 
-    input_data1 = paddle.to_tensor([2, np.nan, 5], dtype='float32')
-    input_data2 = paddle.to_tensor([1, 4, np.nan], dtype='float32')
+    input_data1 = paddle.to_tensor([2, np.nan, 5], dtype="float32")
+    input_data2 = paddle.to_tensor([1, 4, np.nan], dtype="float32")
     verify_model(Subtract(), input_data=[input_data1, input_data2])
 
     input_data1 = paddle.randint(0, 10, (3, 4), dtype="int32")
@@ -1893,7 +1973,7 @@ def unique3(x):
     @paddle.jit.to_static
     def unique4(x):
         return paddle.unique(x, return_index=False, return_inverse=False, return_counts=True)
-
+
     @paddle.jit.to_static
     def unique5(x):
         return paddle.unique(x, return_index=True, return_inverse=True, return_counts=False)
@@ -2018,7 +2098,9 @@ def forward(self, x):
     test_forward_logical_op()
     test_forward_look_up()
     test_forward_lstm()
+    test_forward_masked_select()
     test_forward_matmul()
+    test_forward_meshgrid()
     test_forward_mm()
     test_forward_mv()
     test_forward_multiply()
@@ -2033,6 +2115,8 @@ def forward(self, x):
     test_forward_reduce_op()
     test_forward_reshape()
     test_forward_scale()
+    test_forward_scatter()
+    test_forward_scatter_nd()
     test_forward_slice()
     test_forward_sort()
     test_forward_split()
