Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[0D-Tensor] Support Unary OP (#1478)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 authored May 26, 2023
1 parent cf2e3f0 commit 67aa13a
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/reciprocal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(const framework::NodeAttr &att

std::vector<framework::shape_t> InferShapeForReciprocal(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
CHECK(!inputs_shape.empty()) << "The input's shape size is empty! Please check again.";
std::vector<framework::shape_t> res{inputs_shape[0]};
return res;
}
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ std::shared_ptr<OpStrategy> StrategyForRelu(const framework::NodeAttr &attrs,

std::vector<framework::shape_t> InferShapeForRelu(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again.";
CHECK(!inputs_shape.empty()) << "The input's shape is empty! Please check again.";
std::vector<framework::shape_t> res{inputs_shape[0]};
return res;
}
Expand Down
124 changes: 119 additions & 5 deletions python/tests/ops/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def build_paddle_program(self, target):
self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("elementwise_op")
builder = NetBuilder("binary_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x")
y = builder.create_input(
Expand Down Expand Up @@ -179,11 +179,11 @@ def setUp(self):
def init_dtype(self):
self.dtype = dtype

def paddle_func(self, x, y):
return fn_paddle(x, y)
def paddle_func(self, *args):
return fn_paddle(*args)

def cinn_func(self, builder, x, y):
return eval(fn_cinn)(x, y)
def cinn_func(self, builder, *args):
return eval(fn_cinn)(*args)

cls_name = "{}_{}".format(parent.__name__, test_name)
TestClass.__name__ = cls_name
Expand Down Expand Up @@ -379,5 +379,119 @@ def cinn_func(self, builder, x, y):
"builder.less_equal",
dtype="int64")


######################
#### TestUnaryOp ####
######################
@OpTestTool.skip_if(not is_compiled_with_cuda(),
"x86 test will be skipped due to timeout.")
class TestUnaryOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.init_dtype()
self.init_input()

def init_dtype(self):
self.dtype = "float32"

def init_input(self):
self.inputs = {
"x": np.random.uniform(0.0, 1.0, []).astype(self.dtype),
}
self.target_shape = ()

def paddle_func(self, x):
return paddle.sqrt(x)

def cinn_func(self, builder, x):
return builder.sqrt(x)

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = self.paddle_func(x)

self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("unary_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x")
out = self.cinn_func(builder, x)

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]],
[out])

self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)

def test_check_results(self):
self.check_outputs_and_grads()


create_unit_test(TestUnaryOp, "tanh", paddle.tanh, "builder.tanh")
create_unit_test(TestUnaryOp, "relu", paddle.nn.functional.relu,
"builder.relu")
create_unit_test(TestUnaryOp, "gelu", paddle.nn.functional.gelu,
"builder.gelu")
create_unit_test(TestUnaryOp, "sigmoid", paddle.nn.functional.sigmoid,
"builder.sigmoid")
create_unit_test(TestUnaryOp, "exp", paddle.exp, "builder.exp")
create_unit_test(TestUnaryOp, "erf", paddle.erf, "builder.erf")
create_unit_test(TestUnaryOp, "rsqrt", paddle.rsqrt, "builder.rsqrt")
create_unit_test(TestUnaryOp, "log", paddle.log, "builder.log")
create_unit_test(TestUnaryOp, "log2", paddle.log2, "builder.log2")
create_unit_test(TestUnaryOp, "log10", paddle.log10, "builder.log10")
create_unit_test(TestUnaryOp, "floor", paddle.floor, "builder.floor")
create_unit_test(TestUnaryOp, "ceil", paddle.ceil, "builder.ceil")
create_unit_test(TestUnaryOp, "round", paddle.round, "builder.round")
create_unit_test(TestUnaryOp, "trunc", paddle.trunc, "builder.trunc")
create_unit_test(TestUnaryOp, "sin", paddle.sin, "builder.sin")
create_unit_test(TestUnaryOp, "cos", paddle.cos, "builder.cos")
create_unit_test(TestUnaryOp, "tan", paddle.tan, "builder.tan")
create_unit_test(TestUnaryOp, "sinh", paddle.sinh, "builder.sinh")
create_unit_test(TestUnaryOp, "cosh", paddle.cosh, "builder.cosh")
create_unit_test(TestUnaryOp, "asin", paddle.asin, "builder.asin")
create_unit_test(TestUnaryOp, "acos", paddle.acos, "builder.acos")
create_unit_test(TestUnaryOp, "atan", paddle.atan, "builder.atan")
create_unit_test(TestUnaryOp, "asinh", paddle.asinh, "builder.asinh")
create_unit_test(TestUnaryOp, "atanh", paddle.atanh, "builder.atanh")
create_unit_test(TestUnaryOp, "isnan", paddle.isnan, "builder.is_nan")
create_unit_test(TestUnaryOp, "isfinite", paddle.isfinite, "builder.is_finite")
create_unit_test(TestUnaryOp, "isinf", paddle.isinf, "builder.is_inf")
create_unit_test(
TestUnaryOp,
"logical_not",
paddle.logical_not,
"builder.logical_not",
dtype="bool")
create_unit_test(
TestUnaryOp,
"bitwise_not",
paddle.bitwise_not,
"builder.bitwise_not",
dtype="int64")
create_unit_test(TestUnaryOp, "negative", paddle.neg, "builder.negative")
create_unit_test(TestUnaryOp, "sign", paddle.sign, "builder.sign")
create_unit_test(TestUnaryOp, "abs", paddle.abs, "builder.abs")
create_unit_test(TestUnaryOp, "reciprocal", paddle.reciprocal,
"builder.reciprocal")


# acosh requires input value > 1.0, specific init_input instead of using create_unit_test
class TestUnaryOp_acosh(TestUnaryOp):
def init_input(self):
self.inputs = {
"x": np.random.uniform(1.0, 10.0, []).astype(self.dtype),
}
self.target_shape = ()

def paddle_func(self, x):
return paddle.acosh(x)

def cinn_func(self, builder, x):
return builder.acosh(x)


if __name__ == "__main__":
unittest.main()

0 comments on commit 67aa13a

Please sign in to comment.