diff --git a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
index 92dfecab6dd09..a4212555fc682 100644
--- a/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
+++ b/paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
@@ -34,33 +34,73 @@ class ExpandZeroDimPass : public ProgramPass {
                  const std::unordered_set<std::string>& fetch_ids,
                  const common::Target& target) override {
     NetBuilder builder("expand_zero_dim_builder");
-    for (auto var : program->GetInputs()) {
-      if (var->shape.empty()) {
-        var->shape.push_back(1);
-      }
-      builder.CreateInput(var);
-    }
     for (int i = 0; i < program->size(); ++i) {
       auto& instr = (*program)[i];
+      if (instr->op_type == "transpose") {
+        builder.AppendInstruction(HandleTranspose(instr));
+        continue;
+      }
       for (auto& input : instr->inputs) {
         if (input->shape.empty()) {
-          VLOG(4) << "Change input 0D-Tensor " << input->id << " to 1D-Tensor";
+          VLOG(4) << "Change " << instr->op_type << "'s input 0D-Tensor "
+                  << input->id << " to 1D-Tensor";
           input->shape.push_back(1);
         }
       }
       for (auto& output : instr->outputs) {
         if (output->shape.empty()) {
-          VLOG(4) << "Change output 0D-Tensor " << output->id
-                  << " to 1D-Tensor";
+          VLOG(4) << "Change " << instr->op_type << "'s output 0D-Tensor "
+                  << output->id << " to 1D-Tensor";
           output->shape.push_back(1);
         }
       }
       builder.AppendInstruction(instr);
     }
+    for (auto var : program->GetInputs()) {
+      if (var->shape.empty()) {
+        VLOG(4) << "Change program's input 0D-Tensor " << var->id
+                << " to 1D-Tensor";
+        var->shape.push_back(1);
+      }
+      builder.CreateInput(var);
+    }
     *program = builder.Build();
   }
 
   void Clear() override {}
+
+ private:
+  // Before: out-0D = transpose(x-0D, [])
+  // After:  out-1D = transpose(x-1D, [0])
+  Instruction HandleTranspose(const Instruction& instr) {
+    Instruction new_instr = instr;
+    bool has_0d_input = false;
+    for (auto& input : new_instr->inputs) {
+      if (input->shape.empty()) {
+        VLOG(4) << "Change transpose's input 0D-Tensor " << input->id
+                << " to 1D-Tensor";
+        input->shape.push_back(1);
+        has_0d_input = true;
+      }
+    }
+    for (auto& output : new_instr->outputs) {
+      if (output->shape.empty()) {
+        VLOG(4) << "Change transpose's output 0D-Tensor " << output->id
+                << " to 1D-Tensor";
+        output->shape.push_back(1);
+      }
+    }
+    if (has_0d_input) {
+      std::vector<int> axis =
+          new_instr.GetAttrs<std::vector<int>>("axis");
+      CHECK(axis.empty()) << "transpose's axis should be empty when inputs "
+                             "0D-Tensor! Please check setting.\n";
+      // A rank-1 tensor has exactly one valid permutation: axis = [0].
+      axis.push_back(0);
+      VLOG(4) << "Change Transpose's attribute axis from [] to [0]";
+      new_instr.SetAttr<std::vector<int>>("axis", axis);
+    }
+    return new_instr;
+  }
 };
 
 }  // namespace pass
diff --git a/paddle/cinn/hlir/op/transform.cc b/paddle/cinn/hlir/op/transform.cc
index 78d24cba7035d..7df8e440c6838 100644
--- a/paddle/cinn/hlir/op/transform.cc
+++ b/paddle/cinn/hlir/op/transform.cc
@@ -1137,8 +1137,8 @@ std::vector<std::vector<int>> InferShapeForTranspose(
     const std::vector<std::vector<int>> &inputs_shape,
     const framework::AttrMapType &attrs) {
   std::vector<std::vector<int>> result;
-  CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
-      << "The input's shape size is 0! Please check again.";
+  CHECK(!inputs_shape.empty())
+      << "The input's shape is empty! Please check again.";
   if (attrs.find("axis") != attrs.end()) {
     auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
     CHECK_EQ(axis.size(), inputs_shape[0].size())
diff --git a/test/cinn/ops/test_zero_dim_tensor.py b/test/cinn/ops/test_zero_dim_tensor.py
index 3ba7ac3bc7591..16c110b2298a2 100644
--- a/test/cinn/ops/test_zero_dim_tensor.py
+++ b/test/cinn/ops/test_zero_dim_tensor.py
@@ -630,6 +630,44 @@ def test_check_results(self):
         self.check_outputs_and_grads()
 
 
+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestTransposeOp(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.target_shape = ()
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.transpose(x, [])
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("transpose_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.transpose(x, [])
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(res[0].shape, self.target_shape)
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )