Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[0D-Tensor] Add ExpandZeroDimPass, support 0D input and output #1428

Merged
merged 9 commits into from
May 22, 2023
11 changes: 7 additions & 4 deletions cinn/frontend/decomposer/broadcast.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,13 @@ void elementwise_add(const Instruction& instr, const DecomposerContext& context)
void elementwise_add_grad(const Instruction& instr, const DecomposerContext& context) {
CHECK_EQ(instr->inputs.size(), 3UL) << " 3 input tensors for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 2UL) << "2 output tensors for " << instr->op_type;
auto dout = instr->inputs[0];
auto dx = instr->outputs[0];
auto dy = instr->outputs[1];
int axis = instr.GetAttrs<int>("axis");
auto dout = instr->inputs[0];
auto dx = instr->outputs[0];
auto dy = instr->outputs[1];
int axis = instr.GetAttrs<int>("axis");
if (axis < 0 && dx->shape.size() < dy->shape.size()) {
jiahy0825 marked this conversation as resolved.
Show resolved Hide resolved
LOG(FATAL) << "Please make sure x'rank greater than or equal to y'rank when axis = -1";
}
axis = axis >= 0 ? axis : dx->shape.size() - dy->shape.size();
auto* builder = context.builder();

Expand Down
2 changes: 1 addition & 1 deletion cinn/frontend/decomposer_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DecomposerContext {
// Map the new var to the original var.
void MapOutToOrigin(const Variable& new_var, const Variable& ori_var) const {
if (new_var->shape != ori_var->shape) {
LOG(FATAL) << "The output shape shoule be equal to the original. But received : " << new_var->id << ".shape=["
LOG(FATAL) << "The output shape should be equal to the original. But received : " << new_var->id << ".shape=["
<< utils::Join(new_var->shape, ", ") << "] and the original var " << ori_var->id << ".shape=["
<< utils::Join(ori_var->shape, ", ") << "].";
}
Expand Down
2 changes: 1 addition & 1 deletion cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Placeholder NetBuilder::CreateInput(const Type& type, const std::vector<int>& sh
}

Placeholder NetBuilder::CreateInput(const Variable& var) {
CHECK(!var->shape.empty()) << "The input's shape is not set yet";
VLOG_IF(4, var->shape.empty()) << "The input's shape is empty, Create 0D-Tensor for " << var->id;
CHECK(!var->type.is_unk()) << "The input's type is not set yet";
inputs_.push_back(var);
return Placeholder(var);
Expand Down
1 change: 1 addition & 0 deletions cinn/frontend/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace frontend {

OptimizeOptions DefaultTrainingOptimizeOptions() {
OptimizeOptions options;
options.program_passes.emplace_back("ExpandZeroDim");
options.program_passes.emplace_back("AutoCast");
options.program_passes.emplace_back("Decomposer");
options.program_passes.emplace_back("RemoveIdentity");
Expand Down
2 changes: 2 additions & 0 deletions cinn/frontend/pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ gather_srcs(cinnapi_src SRCS
fill_constant_folding.cc
cast_collapsing.cc
auto_cast.cc
expand_zero_dim_pass.cc
auto_broadcast.cc
)

Expand All @@ -32,3 +33,4 @@ endif()
cc_test(test_transpose_collapsing SRCS transpose_collapsing_test.cc DEPS cinncore)
cc_test(test_cast_collapsing SRCS cast_collapsing_test.cc DEPS cinncore)
cc_test(test_auto_cast SRCS auto_cast_test.cc DEPS cinncore)
cc_test(test_expand_zero_dim_pass SRCS expand_zero_dim_pass_test.cc DEPS cinncore)
73 changes: 73 additions & 0 deletions cinn/frontend/pass/expand_zero_dim_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>

#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/program_pass.h"
#include "glog/logging.h"

namespace cinn {
namespace frontend {
namespace pass {

class ExpandZeroDimPass : public ProgramPass {
 public:
  using ProgramPass::ProgramPass;

 protected:
  // Rewrites the program so that every 0D tensor (empty `shape`) becomes a
  // 1D tensor of shape [1]. Downstream passes and codegen can then assume
  // every tensor has rank >= 1.
  //
  // @param program   The program to rewrite in place (replaced by a rebuilt one).
  // @param fetch_ids Ids of fetched outputs (unused here; required by the API).
  // @param target    Compilation target (unused here; required by the API).
  void ApplyImpl(Program* program,
                 const std::unordered_set<std::string>& fetch_ids,
                 const common::Target& target) override {
    NetBuilder builder("expand_zero_dim_builder");
    // Re-declare every program input, expanding 0D shapes to [1].
    // NOTE(review): mutating var->shape through a by-value copy assumes
    // Variable has shared (pointer-like) semantics so the change is visible
    // program-wide -- confirm against Variable's definition.
    for (auto var : program->GetInputs()) {
      if (var->shape.empty()) {
        var->shape.push_back(1);
      }
      builder.CreateInput(var);
    }
    // Expand any remaining 0D inputs/outputs of each instruction, then replay
    // the instruction unchanged into the new builder.
    // size_t avoids a signed/unsigned comparison with program->size().
    for (size_t i = 0; i < program->size(); ++i) {
      auto& instr = (*program)[i];
      for (auto& input : instr->inputs) {
        if (input->shape.empty()) {
          VLOG(4) << "Change input 0D-Tensor " << input->id << " to 1D-Tensor";
          input->shape.push_back(1);
        }
      }
      for (auto& output : instr->outputs) {
        if (output->shape.empty()) {
          VLOG(4) << "Change output 0D-Tensor " << output->id << " to 1D-Tensor";
          output->shape.push_back(1);
        }
      }
      builder.AppendInstruction(instr);
    }
    // Replace the original program with the rebuilt, rank>=1 version.
    *program = builder.Build();
  }

  void Clear() override {}
};

} // namespace pass
} // namespace frontend
} // namespace cinn

// Registers ExpandZeroDimPass under the name "ExpandZeroDim" so it can be
// requested by ProgramPass::Apply and enabled in optimization pipelines.
CINN_REGISTER_HELPER(ExpandZeroDim) {
  CINN_REGISTER_PROGRAM_PASS(ExpandZeroDim, cinn::frontend::pass::ExpandZeroDimPass);

  return true;
}
157 changes: 157 additions & 0 deletions cinn/frontend/pass/expand_zero_dim_pass_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include <cfloat>

#include "cinn/cinn.h"
#include "cinn/frontend/decomposer/test_helper.h"
#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/optimize.h"
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/op/use_ops.h"
#include "cinn/hlir/pass/use_pass.h"
#include "cinn/utils/data_util.h"

namespace cinn {
namespace frontend {

// Returns the number of elements for a tensor of the given shape: the product
// of all dimensions, or 1 for an empty (0D) shape. Read-only, so the shape is
// taken by const reference (the original non-const ref forbade temporaries).
int GetSize(const std::vector<int>& shape) {
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}

// Creates a random float buffer (values in [0, 1], 1e-3 granularity via
// InitRandomVector) for each input variable, keyed by the variable id.
// The original signature took `const std::vector<Variable>&&` -- a const
// rvalue reference, which only binds temporaries yet cannot move from them.
// A plain const lvalue reference accepts both lvalues and temporaries, so
// all existing call sites keep working.
// NOTE(review): assumes variable ids are unique; a duplicate id would
// overwrite the earlier buffer -- confirm with callers.
std::unordered_map<std::string, std::vector<float>> GetInputRandom(const std::vector<Variable>& inputs) {
  std::unordered_map<std::string, std::vector<float>> input_data;
  for (auto input : inputs) {
    input_data[input->id] = std::vector<float>(GetSize(input->shape));
    InitRandomVector<float>(&input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3);
  }

  return input_data;
}

std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram(
const Program& program,
const Target& target,
const std::unordered_map<std::string, std::vector<float>>& input_data,
const std::unordered_set<std::string>& fetch_ids) {
auto graph = std::make_shared<hlir::framework::Graph>(program, fetch_ids, target);
auto scope = hlir::framework::BuildScope(target, graph);

hlir::framework::ApplyPasses(graph.get(), {"InferShape"});
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());
VLOG(1) << "graph:\n" << graph->Visualize();
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();
for (auto& data : input_data) {
scope->Var<hlir::framework::Tensor>(data.first);
auto tensor = scope->GetTensor(data.first);
CopyFromVector(data.second, tensor, target);
}
runtime_program->Execute();

std::unordered_map<std::string, hlir::framework::Tensor> outputs;
for (auto id : fetch_ids) {
auto tensor = scope->GetTensor(id);
outputs[id] = tensor;
}
return outputs;
}

// Checks that ExpandZeroDimPass turns two 0D inputs (and the 0D output of
// their add) into 1D tensors of shape [1], without adding or removing
// instructions, and that the program still executes correctly.
TEST(ExpandZeroDimPass, expand_zero_dim_1) {
  NetBuilder builder("expand_zero_dim_1");
  auto x       = builder.CreateInput(Float(32), {}, "x");
  auto y       = builder.CreateInput(Float(32), {}, "y");
  auto out     = builder.Add(x, y);
  auto program = builder.Build();
  auto target  = common::DefaultTarget();

  size_t origin_size = program.size();
  VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program;
  /*
  Program {
    Var(var_1: shape=[], dtype=float32)
    Var(y: shape=[], dtype=float32)
    Var(x: shape=[], dtype=float32)

    var_1 = elementwise_add(x, y, axis=-1)
  }
  */
  ProgramPass::Apply(&program, {}, target, {"ExpandZeroDim"});
  size_t pass_size = program.size();
  VLOG(1) << "Program after ExpandZeroDimPass:\n" << program;
  /*
  Program {
    Var(var_1: shape=[1], dtype=float32)
    Var(y: shape=[1], dtype=float32)
    Var(x: shape=[1], dtype=float32)

    var_1 = elementwise_add(x, y, axis=-1)
  }
  */
  // The pass only expands shapes; it must not change the instruction count.
  // (Previously origin_size/pass_size were computed but never checked.)
  ASSERT_EQ(origin_size, pass_size);
  auto input_data = GetInputRandom({x, y});
  auto fetch_ids  = {out->id};
  auto outputs    = RunWithProgram(program, target, input_data, fetch_ids);
  for (auto iter : outputs) {
    // output var_1: shape=[1]
    ASSERT_EQ(iter.second->shape().data().size(), 1);
  }
}

// Checks that ExpandZeroDimPass expands only the 0D input (y -> shape [1])
// while leaving the already-ranked tensors and the instruction count intact,
// and that the broadcast add still executes correctly.
TEST(ExpandZeroDimPass, expand_zero_dim_2) {
  // Fixed copy-paste bug: this builder was previously named
  // "expand_zero_dim_1", colliding with the first test's builder name.
  NetBuilder builder("expand_zero_dim_2");
  auto x       = builder.CreateInput(Float(32), {3, 5}, "x");
  auto y       = builder.CreateInput(Float(32), {}, "y");
  auto out     = builder.Add(x, y);
  auto program = builder.Build();
  auto target  = common::DefaultTarget();

  size_t origin_size = program.size();
  VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program;
  /*
  Program {
    Var(var_1: shape=[3, 5], dtype=float32)
    Var(y: shape=[], dtype=float32)
    Var(x: shape=[3, 5], dtype=float32)

    var_1 = elementwise_add(x, y, axis=-1)
  }
  */
  ProgramPass::Apply(&program, {}, target, {"ExpandZeroDim"});
  size_t pass_size = program.size();
  VLOG(1) << "Program after ExpandZeroDimPass:\n" << program;
  /*
  Program {
    Var(var_1: shape=[3, 5], dtype=float32)
    Var(y: shape=[1], dtype=float32)
    Var(x: shape=[3, 5], dtype=float32)

    var_1 = elementwise_add(x, y, axis=-1)
  }
  */
  // The pass only expands shapes; it must not change the instruction count.
  // (Previously origin_size/pass_size were computed but never checked.)
  ASSERT_EQ(origin_size, pass_size);
  auto input_data = GetInputRandom({x, y});
  auto fetch_ids  = {out->id};
  auto outputs    = RunWithProgram(program, target, input_data, fetch_ids);
  for (auto iter : outputs) {
    // output var_1: shape=[3, 5]
    ASSERT_EQ(iter.second->shape().data().size(), 2);
  }
}

} // namespace frontend
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/frontend/pass/use_program_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "cinn/common/macros.h"

CINN_USE_REGISTER(ExpandZeroDim)
CINN_USE_REGISTER(AutoCast)
CINN_USE_REGISTER(Decomposer)
CINN_USE_REGISTER(DeadCodeEliminate)
Expand Down
Loading