forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from zyfncg/my_drr
[DRR] Add Basic Class
- Loading branch information
Showing
4 changed files
with
475 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/ir/pattern_rewrite/drr/api/drr_pass_context.h" | ||
|
||
#include <glog/logging.h> | ||
#include "paddle/ir/pattern_rewrite/drr/pattern_graph.h" | ||
|
||
namespace ir { | ||
namespace drr { | ||
|
||
const Op& DrrPassContext::SourceOpPattern( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}) { | ||
owned_ops_.push_back(std::make_shared<drr::Op>( | ||
op_type, attributes, source_pattern_graph_.get())); | ||
return *owned_ops_.back(); | ||
} | ||
|
||
const drr::Tensor& DrrPassContext::SourceTensorPattern( | ||
const std::string& tensor_id) { | ||
return source_pattern_graph_->AddTensor( | ||
std::make_shared<drr::Tensor>(tensor_id, source_pattern_graph_.get())); | ||
} | ||
|
||
const Op& DrrPassContext::ResultOpPattern( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}) { | ||
owned_ops_.push_back(std::make_shared<drr::Op>( | ||
op_type, attributes, result_pattern_graph_.get())); | ||
return *owned_ops_.back(); | ||
} | ||
|
||
const drr::Tensor& DrrPassContext::SourceTensorPattern( | ||
const std::string& tensor_id) { | ||
return result_pattern_graph_->AddTensor( | ||
std::make_shared<drr::Tensor>(tensor_id, result_pattern_graph_.get())); | ||
} | ||
|
||
void Op::operator()(const Tensor& arg, const Tensor* out) const { | ||
std::vector<std::weak_ptr<const Tensor>> inputs{arg.shared_from_this()}; | ||
std::vector<std::weak_ptr<const Tensor>> outputs{out->shared_from_this()}; | ||
pattern_graph_->AddOpCall( | ||
std::make_shared<OpCall>(shared_from_this(), inputs, outputs)); | ||
} | ||
|
||
Tensor& Op::operator()(const Tensor& arg) const { | ||
std::vector<std::weak_ptr<const Tensor>> inputs{arg.shared_from_this()}; | ||
auto& out = pattern_graph_->AddTmpTensor(std::make_shared<Tensor>( | ||
"tmp_" + op_type_name_ + std::to_string(count++), pattern_graph_)); | ||
std::vector<std::weak_ptr<const Tensor>> outputs{out.shared_from_this()}; | ||
pattern_graph_->AddOpCall( | ||
std::make_shared<OpCall>(shared_from_this(), inputs, outputs)); | ||
return out; | ||
} | ||
|
||
Tensor& Op::operator()() const { | ||
std::vector<std::weak_ptr<const Tensor>> inputs{}; | ||
auto& out = pattern_graph_->AddTmpTensor(std::make_shared<Tensor>( | ||
"tmp_" + op_type_name_ + std::to_string(count++), pattern_graph_)); | ||
std::vector<std::weak_ptr<const Tensor>> outputs{out.shared_from_this()}; | ||
pattern_graph_->AddOpCall( | ||
std::make_shared<OpCall>(shared_from_this(), inputs, outputs)); | ||
return out; | ||
} | ||
|
||
int64_t Op::count = 0; | ||
|
||
void Tensor::operator=(Tensor& other) const { // NOLINT | ||
// The two tensor must be in the same pattern graph. | ||
CHECK(this->pattern_graph_ == other.pattern_graph_); | ||
if (other.tensor_id_.substr(0, 4) == "tmp_") { | ||
pattern_graph_->UpdateTmpTensor(other.tensor_id_, this->tensor_id_); | ||
} | ||
} | ||
|
||
} // namespace drr | ||
} // namespace ir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <functional> | ||
#include <memory> | ||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
namespace ir { | ||
namespace drr { | ||
|
||
class Op; | ||
class Tensor; | ||
class OpCall; | ||
class Constrain; | ||
class SourcePattern; | ||
class ResultPattern; | ||
class PatternGraph; | ||
|
||
using id_type = std::string; | ||
|
||
class DrrPassContext : public std::enable_shared_from_this<DrrPassContext> { | ||
public: | ||
DrrPassContext() = default; | ||
~DrrPassContext() = default; | ||
|
||
drr::SourcePattern SourcePattern() { return drr::SourcePattern(this); } | ||
|
||
private: | ||
friend class drr::SourcePattern; | ||
friend class drr::ResultPattern; | ||
|
||
const Op& SourceOpPattern( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}); | ||
const drr::Tensor& SourceTensorPattern(const std::string& tensor_id); | ||
|
||
const Op& ResultOpPattern( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}); | ||
const drr::Tensor& ResultTensorPattern(const std::string& tensor_id); | ||
|
||
std::shared_ptr<SourcePatternGraph> source_pattern_graph_; | ||
std::vector<std::unique_ptr<const Constrain>> constraints_; | ||
std::shared_ptr<ResultPatternGraph> result_pattern_graph_; | ||
|
||
std::vector<std::shared_ptr<const drr::Op>> owned_ops_; | ||
}; | ||
|
||
class DrrPass { | ||
public: | ||
virtual void operator()(DrrPassContext* ctx) const; | ||
}; | ||
|
||
class Attribute { | ||
public: | ||
explicit Attribute(const std::string& id) : attr_id_(id) {} | ||
|
||
enum class Type { OP_ATTR, TENSOR_SHAPE, TENSOR_DTYPE }; | ||
|
||
Type type() const { return type_; } | ||
|
||
private: | ||
std::string attr_id_; | ||
}; | ||
|
||
class TensorShape : public Attribute { | ||
public: | ||
explicit TensorShape(const std::string& tensor_id) | ||
: Attribute(tensor_id + "_shape_"), tensor_id_(tensor_id) {} | ||
|
||
private: | ||
std::string tensor_id_; | ||
}; | ||
|
||
class Op : public std::enable_shared_from_this<Op> { | ||
public: | ||
void operator()(const Tensor& arg, const Tensor* out) const; | ||
|
||
Tensor& operator()() const; | ||
|
||
Tensor& operator()(const Tensor& arg) const; | ||
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; | ||
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const | ||
// Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const | ||
// Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& | ||
// operator()(const Tensor& arg0, const Tensor& arg1, const Tensor& arg2, | ||
// const Tensor& arg3, const Tensor& arg4) const; | ||
// void operator()(const std::vector<Tensor>& args, const | ||
// std::vector<Tensor*>& outputs) const; | ||
|
||
private: | ||
friend class SourcePattern; | ||
|
||
Op(const std::string& op_type_name, | ||
const std::unordered_map<std::string, Attribute>& attributes, | ||
PatternGraph* pattern_graph) | ||
: op_type_name_(op_type_name), | ||
attributes_(attributes), | ||
pattern_graph_(pattern_graph) {} | ||
|
||
static int64_t count; | ||
std::string op_type_name_; | ||
std::unordered_map<std::string, Attribute> attributes_; | ||
PatternGraph* pattern_graph_; | ||
}; | ||
|
||
class Tensor : public std::enable_shared_from_this<Tensor> { | ||
public: | ||
const std::string& DebugName() const; | ||
|
||
TensorShape shape() const { return TensorShape(id()); } | ||
|
||
Tensor& operator=(const Tensor& other) = delete; | ||
|
||
void operator=(Tensor& other) const; // NOLINT | ||
|
||
const id_type& id() const { return tensor_id_; } | ||
|
||
void set_id(const id_type& id) { tensor_id_ = id; } | ||
|
||
std::weak_ptr<OpCall> producer() const { return producer_; } | ||
|
||
void set_producer(std::weak_ptr<OpCall> producer) { producer_ = producer; } | ||
|
||
const std::unordered_set<std::weak_ptr<const OpCall>>& consumers() const { | ||
return consumers_; | ||
} | ||
|
||
void set_consumables( | ||
const std::unordered_set<std::weak_ptr<const OpCall>>& consumers) { | ||
consumers_ = consumers; | ||
} | ||
|
||
void AddConsumer(std::weak_ptr<const OpCall> consumer) { | ||
consumers_.insert(consumer); | ||
} | ||
|
||
private: | ||
friend class DrrPassContext; | ||
friend class Op; | ||
|
||
// explicit Tensor(const id_type& tensor_id) : tensor_id_(tensor_id) {} | ||
|
||
Tensor(const id_type& tensor_id, PatternGraph* pattern_graph) | ||
: tensor_id_(tensor_id), pattern_graph_(pattern_graph) {} | ||
|
||
id_type tensor_id_; | ||
std::weak_ptr<OpCall> producer_; | ||
std::unordered_set<std::weak_ptr<const OpCall>> consumers_; | ||
PatternGraph* pattern_graph_; | ||
}; | ||
|
||
class OpCall : public std::enable_shared_from_this<OpCall> { | ||
public: | ||
OpCall(std::weak_ptr<const Op> op, | ||
const std::vector<std::weak_ptr<const Tensor>>& inputs, | ||
const std::vector<std::weak_ptr<const Tensor>>& outputs) | ||
: op_(op), inputs_(inputs), outputs_(outputs) {} | ||
|
||
const std::vector<std::weak_ptr<const Tensor>>& inputs() const { | ||
return inputs_; | ||
} | ||
|
||
const std::vector<std::weak_ptr<const Tensor>>& outputs() const { | ||
return outputs_; | ||
} | ||
|
||
private: | ||
id_type op_call_id_; | ||
std::weak_ptr<const Op> op_; | ||
std::vector<std::weak_ptr<const Tensor>> inputs_; | ||
std::vector<std::weak_ptr<const Tensor>> outputs_; | ||
}; | ||
|
||
class ResultPattern { | ||
public: | ||
const drr::Op& Op( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}) { | ||
return ctx_->ResultOpPattern(op_type, attributes); | ||
} | ||
|
||
const drr::Tensor& Tensor(const std::string& tensor_id) { | ||
return ctx_->ResultTensorPattern(tensor_id); | ||
} | ||
|
||
Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } | ||
|
||
private: | ||
friend class SourcePattern; | ||
|
||
explicit ResultPattern(DrrPassContext* ctx) : ctx_(ctx) {} | ||
|
||
DrrPassContext* ctx_; | ||
}; | ||
|
||
class SourcePattern { | ||
public: | ||
ResultPattern ResultPattern() const { return ResultPattern(ctx_); } | ||
|
||
const drr::Op& Op( | ||
const std::string& op_type, | ||
const std::unordered_map<std::string, Attribute>& attributes = {}) { | ||
return ctx_->SourceOpPattern(op_type, attributes); | ||
} | ||
|
||
const drr::Tensor& Tensor(const std::string& tensor_id) { | ||
return ctx_->SourceTensorPattern(tensor_id); | ||
} | ||
|
||
Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } | ||
|
||
private: | ||
friend class DrrPassContext; | ||
explicit SourcePattern(DrrPassContext* ctx) : ctx_(ctx) {} | ||
DrrPassContext* ctx_; | ||
}; | ||
|
||
} // namespace drr | ||
} // namespace ir |
Oops, something went wrong.