forked from PaddlePaddle/CINN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[0D-Tensor] Add ExpandZeroDimPass, support 0D input and output (Paddl…
…ePaddle#1428) * [0D-Tensor] Add ExpandZeroDimPass, support 0D input and output * Change graph pass to program pass, enable unit test of pass * polish codes * change year, 2021->2023 * Change CI paddle version 2.4.2->develop * disable paddle temporarily * Restore paddle version in this PR * add python unittest of pass
- Loading branch information
Showing
12 changed files
with
426 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
// Copyright (c) 2023 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <functional> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
#include "cinn/frontend/net_builder.h" | ||
#include "cinn/frontend/program_pass.h" | ||
#include "glog/logging.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
namespace pass { | ||
|
||
class ExpandZeroDimPass : public ProgramPass { | ||
public: | ||
using ProgramPass::ProgramPass; | ||
|
||
protected: | ||
void ApplyImpl(Program* program, | ||
const std::unordered_set<std::string>& fetch_ids, | ||
const common::Target& target) override { | ||
NetBuilder builder("expand_zero_dim_builder"); | ||
for (auto var : program->GetInputs()) { | ||
if (var->shape.empty()) { | ||
var->shape.push_back(1); | ||
} | ||
builder.CreateInput(var); | ||
} | ||
for (int i = 0; i < program->size(); ++i) { | ||
auto& instr = (*program)[i]; | ||
for (auto& input : instr->inputs) { | ||
if (input->shape.empty()) { | ||
VLOG(4) << "Change input 0D-Tensor " << input->id << " to 1D-Tensor"; | ||
input->shape.push_back(1); | ||
} | ||
} | ||
for (auto& output : instr->outputs) { | ||
if (output->shape.empty()) { | ||
VLOG(4) << "Change output 0D-Tensor " << output->id << " to 1D-Tensor"; | ||
output->shape.push_back(1); | ||
} | ||
} | ||
builder.AppendInstruction(instr); | ||
} | ||
*program = builder.Build(); | ||
} | ||
|
||
void Clear() override {} | ||
}; | ||
|
||
} // namespace pass | ||
} // namespace frontend | ||
} // namespace cinn | ||
|
||
CINN_REGISTER_HELPER(ExpandZeroDim) { | ||
CINN_REGISTER_PROGRAM_PASS(ExpandZeroDim, cinn::frontend::pass::ExpandZeroDimPass); | ||
|
||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
// Copyright (c) 2023 CINN Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <cfloat> | ||
|
||
#include "cinn/cinn.h" | ||
#include "cinn/frontend/decomposer/test_helper.h" | ||
#include "cinn/frontend/net_builder.h" | ||
#include "cinn/frontend/optimize.h" | ||
#include "cinn/frontend/pass/use_program_pass.h" | ||
#include "cinn/frontend/program_pass.h" | ||
#include "cinn/frontend/syntax.h" | ||
#include "cinn/hlir/framework/graph.h" | ||
#include "cinn/hlir/framework/graph_compiler.h" | ||
#include "cinn/hlir/framework/pass.h" | ||
#include "cinn/hlir/op/use_ops.h" | ||
#include "cinn/hlir/pass/use_pass.h" | ||
#include "cinn/utils/data_util.h" | ||
|
||
namespace cinn { | ||
namespace frontend { | ||
|
||
int GetSize(std::vector<int>& shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()); } | ||
|
||
std::unordered_map<std::string, std::vector<float>> GetInputRandom(const std::vector<Variable>&& inputs) { | ||
std::unordered_map<std::string, std::vector<float>> input_data; | ||
for (auto input : inputs) { | ||
input_data[input->id] = std::vector<float>(GetSize(input->shape)); | ||
InitRandomVector<float>(&input_data[input->id], input_data[input->id].size(), 0.0f, 1.0f, 1e-3); | ||
} | ||
|
||
return input_data; | ||
} | ||
|
||
std::unordered_map<std::string, hlir::framework::Tensor> RunWithProgram( | ||
const Program& program, | ||
const Target& target, | ||
const std::unordered_map<std::string, std::vector<float>>& input_data, | ||
const std::unordered_set<std::string>& fetch_ids) { | ||
auto graph = std::make_shared<hlir::framework::Graph>(program, fetch_ids, target); | ||
auto scope = hlir::framework::BuildScope(target, graph); | ||
|
||
hlir::framework::ApplyPasses(graph.get(), {"InferShape"}); | ||
hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses()); | ||
VLOG(1) << "graph:\n" << graph->Visualize(); | ||
hlir::framework::GraphCompiler gc(target, scope, graph); | ||
auto runtime_program = gc.Build(); | ||
for (auto& data : input_data) { | ||
scope->Var<hlir::framework::Tensor>(data.first); | ||
auto tensor = scope->GetTensor(data.first); | ||
CopyFromVector(data.second, tensor, target); | ||
} | ||
runtime_program->Execute(); | ||
|
||
std::unordered_map<std::string, hlir::framework::Tensor> outputs; | ||
for (auto id : fetch_ids) { | ||
auto tensor = scope->GetTensor(id); | ||
outputs[id] = tensor; | ||
} | ||
return outputs; | ||
} | ||
|
||
TEST(ExpandZeroDimPass, expand_zero_dim_1) { | ||
NetBuilder builder("expand_zero_dim_1"); | ||
auto x = builder.CreateInput(Float(32), {}, "x"); | ||
auto y = builder.CreateInput(Float(32), {}, "y"); | ||
auto out = builder.Add(x, y); | ||
auto program = builder.Build(); | ||
auto target = common::DefaultTarget(); | ||
|
||
size_t origin_size = program.size(); | ||
VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program; | ||
/* | ||
Program { | ||
Var(var_1: shape=[], dtype=float32) | ||
Var(y: shape=[], dtype=float32) | ||
Var(x: shape=[], dtype=float32) | ||
var_1 = elementwise_add(x, y, axis=-1) | ||
} | ||
*/ | ||
ProgramPass::Apply(&program, {}, target, {"ExpandZeroDim"}); | ||
size_t pass_size = program.size(); | ||
VLOG(1) << "Program after ExpandZeroDimPass:\n" << program; | ||
/* | ||
Program { | ||
Var(var_1: shape=[1], dtype=float32) | ||
Var(y: shape=[1], dtype=float32) | ||
Var(x: shape=[1], dtype=float32) | ||
var_1 = elementwise_add(x, y, axis=-1) | ||
} | ||
*/ | ||
auto input_data = GetInputRandom({x, y}); | ||
auto fetch_ids = {out->id}; | ||
auto outputs = RunWithProgram(program, target, input_data, fetch_ids); | ||
for (auto iter : outputs) { | ||
// output var_1: shape=[1] | ||
ASSERT_EQ(iter.second->shape().data().size(), 1); | ||
} | ||
} | ||
|
||
TEST(ExpandZeroDimPass, expand_zero_dim_2) { | ||
NetBuilder builder("expand_zero_dim_1"); | ||
auto x = builder.CreateInput(Float(32), {3, 5}, "x"); | ||
auto y = builder.CreateInput(Float(32), {}, "y"); | ||
auto out = builder.Add(x, y); | ||
auto program = builder.Build(); | ||
auto target = common::DefaultTarget(); | ||
|
||
size_t origin_size = program.size(); | ||
VLOG(1) << "Program Before ExpandZeroDimPass:\n" << program; | ||
/* | ||
Program { | ||
Var(var_1: shape=[3, 5], dtype=float32) | ||
Var(y: shape=[], dtype=float32) | ||
Var(x: shape=[3, 5], dtype=float32) | ||
var_1 = elementwise_add(x, y, axis=-1) | ||
} | ||
*/ | ||
ProgramPass::Apply(&program, {}, target, {"ExpandZeroDim"}); | ||
size_t pass_size = program.size(); | ||
VLOG(1) << "Program after ExpandZeroDimPass:\n" << program; | ||
/* | ||
Program { | ||
Var(var_1: shape=[3, 5], dtype=float32) | ||
Var(y: shape=[1], dtype=float32) | ||
Var(x: shape=[3, 5], dtype=float32) | ||
var_1 = elementwise_add(x, y, axis=-1) | ||
} | ||
*/ | ||
auto input_data = GetInputRandom({x, y}); | ||
auto fetch_ids = {out->id}; | ||
auto outputs = RunWithProgram(program, target, input_data, fetch_ids); | ||
for (auto iter : outputs) { | ||
// output var_1: shape=[3, 5] | ||
ASSERT_EQ(iter.second->shape().data().size(), 2); | ||
} | ||
} | ||
|
||
} // namespace frontend | ||
} // namespace cinn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.