Skip to content

Commit

Permalink
[0D-Tensor] CINN supports transpose, add special case to expand_zero_…
Browse files Browse the repository at this point in the history
…dim_pass (PaddlePaddle#55379)
  • Loading branch information
jiahy0825 authored and wz1qqx committed Jul 31, 2023
1 parent b549913 commit 4edfead
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
58 changes: 49 additions & 9 deletions paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,73 @@ class ExpandZeroDimPass : public ProgramPass {
const std::unordered_set<std::string>& fetch_ids,
const common::Target& target) override {
NetBuilder builder("expand_zero_dim_builder");
for (auto var : program->GetInputs()) {
if (var->shape.empty()) {
var->shape.push_back(1);
}
builder.CreateInput(var);
}
for (int i = 0; i < program->size(); ++i) {
auto& instr = (*program)[i];
if (instr->op_type == "transpose") {
builder.AppendInstruction(HandleTranspose(instr));
continue;
}
for (auto& input : instr->inputs) {
if (input->shape.empty()) {
VLOG(4) << "Change input 0D-Tensor " << input->id << " to 1D-Tensor";
VLOG(4) << "Change " << instr->op_type << "'s input 0D-Tensor "
<< input->id << " to 1D-Tensor";
input->shape.push_back(1);
}
}
for (auto& output : instr->outputs) {
if (output->shape.empty()) {
VLOG(4) << "Change output 0D-Tensor " << output->id
<< " to 1D-Tensor";
VLOG(4) << "Change " << instr->op_type << "'s output 0D-Tensor "
<< output->id << " to 1D-Tensor";
output->shape.push_back(1);
}
}
builder.AppendInstruction(instr);
}
for (auto var : program->GetInputs()) {
if (var->shape.empty()) {
VLOG(4) << "Change program's input 0D-Tensor " << var->id
<< " to 1D-Tensor";
var->shape.push_back(1);
}
builder.CreateInput(var);
}
*program = builder.Build();
}

void Clear() override {}

private:
// Before: out-0D = transpose(x-0D, [])
// After: out-1D = transpose(x-1D, [1])
Instruction HandleTranspose(const Instruction& instr) {
Instruction new_instr = instr;
bool has_0d_input = false;
for (auto& input : new_instr->inputs) {
if (input->shape.empty()) {
VLOG(4) << "Change transpose's input 0D-Tensor " << input->id
<< " to 1D-Tensor";
input->shape.push_back(1);
has_0d_input = true;
}
}
for (auto& output : new_instr->outputs) {
if (output->shape.empty()) {
VLOG(4) << "Change transpose's output 0D-Tensor " << output->id
<< " to 1D-Tensor";
output->shape.push_back(1);
}
}
if (has_0d_input) {
std::vector<int32_t> axis =
new_instr.GetAttrs<std::vector<int32_t>>("axis");
CHECK(axis.empty()) << "transpose's axis should be empty when inputs "
"0D-Tensor! Please check setting.\n";
axis.push_back(0);
VLOG(4) << "Change Transpose's attribute axis from [] to [1]";
new_instr.SetAttr<std::vector<int32_t>>("axis", axis);
}
return new_instr;
}
};

} // namespace pass
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1137,8 +1137,8 @@ std::vector<framework::shape_t> InferShapeForTranspose(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
std::vector<framework::shape_t> result;
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
<< "The input's shape size is 0! Please check again.";
CHECK(!inputs_shape.empty())
<< "The input's shape is empty! Please check again.";
if (attrs.find("axis") != attrs.end()) {
auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
CHECK_EQ(axis.size(), inputs_shape[0].size())
Expand Down
38 changes: 38 additions & 0 deletions test/cinn/ops/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,44 @@ def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestTransposeOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()

def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = ()

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.transpose(x, [])

self.paddle_outputs = [out]

def build_cinn_program(self, target):
builder = NetBuilder("transpose_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.transpose(x, [])

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])

self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)

def test_check_results(self):
self.check_outputs_and_grads()


@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
Expand Down

0 comments on commit 4edfead

Please sign in to comment.