diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 8bc4d77a20a3..665c88629517 100644
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -708,6 +708,17 @@ def convert_hard_swish(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_index_select(g, op, block):
+    """Operator converter for index_select."""
+
+    dim = op.attr("dim")
+    x = g.get_node(op.input("X")[0])
+    index = g.get_node(op.input("Index")[0])
+    out = _op.take(x, indices=index, axis=dim, mode="clip")
+
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_layer_norm(g, op, block):
     """Operator converter for layer_norm."""
 
@@ -1055,6 +1066,33 @@ def convert_pow(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_norm(g, op, block):
+    """Operator converter for norm."""
+
+    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
+    axis = op.attr("axis")
+    keepdim = op.attr("keepdim")
+    if op.attr("asvector"):
+        axis = None
+    order = op.attr("porder")
+    if order == np.inf:
+        out = _op.reduce.max(_op.abs(x), axis=axis, keepdims=keepdim)
+    elif order == -np.inf:
+        out = _op.reduce.min(_op.abs(x), axis=axis, keepdims=keepdim)
+    else:
+        reci_order = _expr.const(1.0 / order, dtype=dtype)
+        order = _expr.const(order, dtype=dtype)
+        out = _op.power(
+            _op.reduce.sum(_op.power(_op.abs(x), order), axis=axis, keepdims=keepdim),
+            reci_order,
+        )
+    if op.attr("asvector") and not keepdim:
+        out = _op.expand_dims(out, axis=0)
+
+    g.add_node(op.output("Out")[0], out)
+
+
 def convert_range(g, op, block):
     """Operator converter for range."""
 
@@ -1567,6 +1605,7 @@ def convert_unsqueeze(g, op, block):
     "greater_than": convert_elementwise_op,
     "hard_sigmoid": convert_hard_sigmoid,
     "hard_swish": convert_hard_swish,
+    "index_select": convert_index_select,
     "isinf": convert_unary_op,
     "isinf_v2": convert_unary_op,
     "layer_norm": convert_layer_norm,
@@ -1576,6 +1615,7 @@ def convert_unsqueeze(g, op, block):
     "lookup_table": convert_lookup_table,
     "lookup_table_v2": convert_lookup_table,
     "log": convert_unary_op,
+    "log2": convert_unary_op,
     "log10": convert_unary_op,
     "log1p": convert_log1p,
     "logsumexp": convert_logsumexp,
@@ -1589,6 +1629,7 @@ def convert_unsqueeze(g, op, block):
     "pad2d": convert_padding,
     "pad3d": convert_padding,
     "pow": convert_pow,
+    "p_norm": convert_norm,
     "range": convert_range,
     "reciprocal": convert_reciprocal,
     "reduce_all": convert_reduce,
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index 655e12de9330..cc63da231804 100644
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -155,6 +155,7 @@ def forward(self, inputs):
         "exp",
         "floor",
         "log",
+        "log2",
         "log10",
         "log1p",
         "numel",
@@ -628,6 +629,26 @@ def ones_like2(inputs):
 
 
 @tvm.testing.uses_gpu
+def test_forward_ones():
+    @paddle.jit.to_static
+    def ones1(inputs):
+        ones = paddle.ones([1, 3, 10, 10])
+        out = inputs + ones
+        return out
+
+    @paddle.jit.to_static
+    def ones2(inputs):
+        shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32")
+        ones = paddle.ones(shape)
+        out = inputs + ones
+        return out
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(ones1, input_data=input_data)
+    verify_model(ones2, input_data=input_data)
+
+
 def test_forward_elemwise():
     class ElemwiseOp(nn.Layer):
         def __init__(self, op_name):
@@ -739,6 +760,23 @@ def hard_swish(inputs):
     verify_model(hard_swish, input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_index_select():
+    @paddle.jit.to_static
+    def index_select1(x, index):
+        return paddle.index_select(x, index)
+
+    @paddle.jit.to_static
+    def index_select2(x, index):
+        return paddle.index_select(x, index, axis=1)
+
+    input_shape = [3, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    index = paddle.to_tensor(np.array([0, 1, 1]).astype("int32"))
+    verify_model(index_select1, input_data=[input_data, index])
+    verify_model(index_select2, input_data=[input_data, index])
+
+
 @tvm.testing.uses_gpu
 def test_forward_isinf():
     @paddle.jit.to_static
@@ -943,6 +981,71 @@ def forward(self, inputs):
     )
 
 
+def test_forward_norm():
+    class Norm1(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=None, keepdim=False)
+
+    class Norm2(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=False)
+
+    class Norm3(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("-inf"), axis=None, keepdim=True)
+
+    class Norm4(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=[1, 2], keepdim=False)
+
+    class Norm5(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=float("inf"), axis=-1, keepdim=True)
+
+    class Norm6(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=0.5, axis=1, keepdim=True)
+
+    class Norm7(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=1.0, axis=None, keepdim=False)
+
+    class Norm8(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=2.0, axis=1, keepdim=False)
+
+    class Norm9(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=-0.5, axis=[1, 2], keepdim=False)
+
+    class Norm10(nn.Layer):
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return paddle.norm(inputs, p=-2.0, axis=1, keepdim=False)
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(Norm1(), input_data=input_data)
+    verify_model(Norm2(), input_data=input_data)
+    verify_model(Norm3(), input_data=input_data)
+    verify_model(Norm4(), input_data=input_data)
+    verify_model(Norm5(), input_data=input_data)
+    verify_model(Norm6(), input_data=input_data)
+    verify_model(Norm7(), input_data=input_data)
+    verify_model(Norm8(), input_data=input_data)
+    verify_model(Norm9(), input_data=input_data)
+    verify_model(Norm10(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_pool2d():
     @paddle.jit.to_static
@@ -1319,6 +1422,27 @@ def tile3(inputs, inputs2):
     verify_model(tile3, input_data=[input_data, input_data2])
 
 
+@tvm.testing.uses_gpu
+def test_forward_zeros():
+    @paddle.jit.to_static
+    def zeros1(inputs):
+        zeros = paddle.zeros([1, 3, 10, 10])
+        out = inputs + zeros
+        return out
+
+    @paddle.jit.to_static
+    def zeros2(inputs):
+        shape = paddle.to_tensor([1, 3, 10, 10], dtype="int32")
+        zeros = paddle.zeros(shape)
+        out = inputs + zeros
+        return out
+
+    input_shape = [1, 3, 10, 10]
+    input_data = paddle.rand(input_shape, dtype="float32")
+    verify_model(zeros1, input_data=input_data)
+    verify_model(zeros2, input_data=input_data)
+
+
 if __name__ == "__main__":
     test_forward_add_subtract()
     test_forward_addmm()
@@ -1338,12 +1462,14 @@ def tile3(inputs, inputs2):
     test_forward_expand()
     test_forward_flatten()
     test_forward_shape_full()
+    test_forward_ones()
     test_forward_ones_like()
     test_forward_gather_assign_value()
     test_forward_gather_nd()
     test_forward_gelu()
     test_forward_hard_sigmoid()
     test_forward_hard_swish()
+    test_forward_index_select()
     test_forward_interpolate()
     test_forward_isinf()
     test_forward_layer_norm()
@@ -1353,6 +1479,7 @@ def tile3(inputs, inputs2):
     test_forward_matmul()
     test_forward_multiply()
     test_forward_nonzero()
+    test_forward_norm()
     test_forward_pool2d()
     test_forward_pad()
     test_forward_pow()
@@ -1366,3 +1493,4 @@ def tile3(inputs, inputs2):
     test_forward_tile()
     test_forward_conv_transpose()
     test_forward_unary_op()
+    test_forward_zeros()
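
The three porder branches in convert_norm are easiest to sanity-check outside of Relay. The snippet below is a minimal NumPy sketch, not part of this change, of the reduction the converter emits; the helper name p_norm_reference is invented for illustration, and it mirrors what verify_model compares against PaddlePaddle's own output:

import numpy as np

# Hypothetical reference, for illustration only. It mirrors convert_norm's
# branches: max(|x|) for porder == inf, min(|x|) for porder == -inf, and
# (sum(|x| ** p)) ** (1 / p) for any other order.
def p_norm_reference(x, porder, axis=None, keepdims=False):
    if porder == np.inf:
        return np.abs(x).max(axis=axis, keepdims=keepdims)
    if porder == -np.inf:
        return np.abs(x).min(axis=axis, keepdims=keepdims)
    return np.power(
        np.power(np.abs(x), porder).sum(axis=axis, keepdims=keepdims),
        1.0 / porder,
    )

x = np.random.rand(1, 3, 10, 10).astype("float32")
print(p_norm_reference(x, 2.0, axis=1).shape)  # (1, 10, 10)

Note that negative orders (Norm9 and Norm10 above) are only well defined for nonzero entries, which the paddle.rand inputs used by the tests make all but certain; axis=None corresponds to the asvector attribute, where the converter reduces over all elements and then restores a rank-1 result via expand_dims.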