
Commit

finished common api
wjj19950828 committed Sep 17, 2021
1 parent bd64abf commit 2355add
Showing 2 changed files with 101 additions and 7 deletions.
60 changes: 54 additions & 6 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -1620,9 +1620,9 @@ def make_init_param_inputs(g, node, layer):
if is_bidirec:
num_directions = 2

-X_shape = infer_shape(input_x)
-time_steps = X_shape[0]
-X_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0)
+x_shape = infer_shape(input_x)
+time_steps = x_shape[0]
+x_steps = _op.split(input_x, indices_or_sections=time_steps, axis=0)
for layer in range(num_layers):
input_weights, hidden_weights, input_bias, hidden_bias = make_param_inputs(
g, op, layer, hidden_size, num_layers
@@ -1642,7 +1642,7 @@ def make_init_param_inputs(g, node, layer):
WB = g.get_node(input_bias[i])
RB = g.get_node(hidden_bias[i])
output, H, C = generate_lstm(
-X_steps=X_steps,
+X_steps=x_steps,
H_t=H_t,
C_t=C_t,
W=W,
@@ -1663,8 +1663,7 @@ def make_init_param_inputs(g, node, layer):

output = _op.transpose(output, axes=[0, 2, 1, 3])
output = _op.reshape(output, newshape=(0, 0, -1))
-X_steps = output
-X_steps = _op.split(X_steps, indices_or_sections=time_steps, axis=0)
+x_steps = _op.split(output, indices_or_sections=time_steps, axis=0)

g.add_node(op.output("Out")[0], output)

@@ -1963,6 +1962,53 @@ def convert_unsqueeze(g, op, block):
g.add_node(op.output("Out")[0], x)


def convert_unstack(g, op, block):
"""Operator converter for unstack."""

x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
num = op.attr("num")
out = _op.split(x, num, axis=axis)
for i, out_i in enumerate(out):
out_i = _op.squeeze(out_i, axis=[axis])
g.add_node(op.output("Y")[i], out_i)


def convert_unique(g, op, block):
"""Operator converter for unique."""

x = g.get_node(op.input("X")[0])
ndim = len(infer_shape(x))
assert ndim == 1, "Only support 1D Tensor for PaddlePaddle's unique"
is_sorted = op.attr("is_sorted")
return_counts = op.attr("return_counts")
return_index = op.attr("return_index")
return_inverse = op.attr("return_inverse")
if return_counts:
[unique, indices, inverse_indices, num_uniq, counts] = _op.unique(
x, is_sorted=is_sorted, return_counts=True
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.strided_slice(counts, begin=[0], end=num_uniq, slice_mode="size")
indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size")
counts_sliced = _op.cast(counts_sliced, "int64")
g.add_node(op.output("Counts")[0], counts_sliced)
else:
[unique, indices, inverse_indices, num_uniq] = _op.unique(
x, is_sorted=is_sorted, return_counts=False
)
unique_sliced = _op.strided_slice(unique, begin=[0], end=num_uniq, slice_mode="size")
indices_sliced = _op.strided_slice(indices, begin=[0], end=num_uniq, slice_mode="size")

inverse_indices = _op.cast(inverse_indices, "int64")
indices_sliced = _op.cast(indices_sliced, "int64")
g.add_node(op.output("Out")[0], unique_sliced)
if return_index:
g.add_node(op.output("Indices")[0], indices_sliced)
if return_inverse:
g.add_node(op.output("Index")[0], inverse_indices)


def convert_where(g, op, block):
"""Operator converter for where."""

@@ -2107,6 +2153,8 @@ def convert_where(g, op, block):
"tile": convert_tile,
"transpose2": convert_transpose,
"unsqueeze2": convert_unsqueeze,
"unstack": convert_unstack,
"unique": convert_unique,
"where": convert_where,
"where_index": convert_nonzero,
}
48 changes: 47 additions & 1 deletion tests/python/frontend/paddlepaddle/test_forward.py
@@ -48,7 +48,6 @@ def get_paddle_model(func, input_spec):
global PADDLE_TEST_DATA_ROOT_PATH
model_path = Path(PADDLE_TEST_DATA_ROOT_PATH, "model")
paddle.jit.save(func, str(model_path), input_spec=input_spec)
-paddle.jit.save(func, "/paddle/pr_for_tvm/0905/tvm/inference_model_test/inference", input_spec=input_spec)
baseline_model = paddle.jit.load(str(model_path))

shutil.rmtree(str(PADDLE_TEST_DATA_ROOT_PATH))
@@ -1824,6 +1823,52 @@ def unstack3(x):
verify_model(unstack3, input_data=[input_data])


@tvm.testing.uses_gpu
def test_forward_unique():
@paddle.jit.to_static
def unique1(x):
return paddle.unique(x)

@paddle.jit.to_static
def unique2(x):
return paddle.unique(x, return_index=True, return_inverse=False, return_counts=False)

@paddle.jit.to_static
def unique3(x):
return paddle.unique(x, return_index=False, return_inverse=True, return_counts=False)

@paddle.jit.to_static
def unique4(x):
return paddle.unique(x, return_index=False, return_inverse=False, return_counts=True)

@paddle.jit.to_static
def unique5(x):
return paddle.unique(x, return_index=True, return_inverse=True, return_counts=False)

@paddle.jit.to_static
def unique6(x):
return paddle.unique(x, return_index=False, return_inverse=True, return_counts=True)

@paddle.jit.to_static
def unique7(x):
return paddle.unique(x, return_index=True, return_inverse=False, return_counts=True)

@paddle.jit.to_static
def unique8(x):
return paddle.unique(x, return_index=True, return_inverse=True, return_counts=True)

input_data = np.array([2, 3, 3, 1, 5, 3])
input_data = paddle.to_tensor(input_data)
verify_model(unique1, input_data=[input_data], input_shape=[[6]])
verify_model(unique2, input_data=[input_data], input_shape=[[6]])
verify_model(unique3, input_data=[input_data], input_shape=[[6]])
verify_model(unique4, input_data=[input_data], input_shape=[[6]])
verify_model(unique5, input_data=[input_data], input_shape=[[6]])
verify_model(unique6, input_data=[input_data], input_shape=[[6]])
verify_model(unique7, input_data=[input_data], input_shape=[[6]])
verify_model(unique8, input_data=[input_data], input_shape=[[6]])


@tvm.testing.uses_gpu
def test_forward_zeros():
@paddle.jit.to_static
@@ -1944,6 +1989,7 @@ def forward(self, x):
test_forward_tile()
test_forward_conv_transpose()
test_forward_unstack()
test_forward_unique()
test_forward_math()
test_forward_zeros()
test_forward_where()
