Skip to content

Commit

Permalink
Added fluid dependencies to Eager Dygraph #2 (#37556)
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 authored Nov 26, 2021
1 parent a9608f6 commit 471fa1e
Show file tree
Hide file tree
Showing 8 changed files with 994 additions and 0 deletions.
258 changes: 258 additions & 0 deletions paddle/fluid/eager/legacy/infer_var_type_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/legacy/tensor_helper.h"
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/pten/api/all.h"
#include "paddle/pten/include/core.h"

namespace egr {

// infer var type context for imperative mode
class TensorRuntimeInferVarTypeContext
: public paddle::framework::InferVarTypeContext {
public:
TensorRuntimeInferVarTypeContext(
const NameTensorMap& inputs, const NameTensorMap& outputs,
const paddle::framework::AttributeMap& attrs_map,
const paddle::framework::AttributeMap& default_attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
default_attrs_(default_attrs_map) {}

virtual ~TensorRuntimeInferVarTypeContext() {}

paddle::framework::Attribute GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);

if (it == attrs_.end()) {
it = default_attrs_.find(name);
if (it == default_attrs_.end()) {
PADDLE_THROW(paddle::platform::errors::NotFound(
"Can not find [%s] in attributes.", name));
}
}

return it->second;
}

bool HasInput(const std::string& name) const override {
auto it = inputs_.find(name);
return (it != inputs_.end() && it->second.size() > 0);
}

bool HasOutput(const std::string& name) const override {
auto it = outputs_.find(name);
return (it != outputs_.end() && it->second.size() > 0);
}

size_t InputSize(const std::string& name) const {
return inputs_.at(name).size();
}

const std::string& InputVarName(const std::string& name,
const int index = 0) const {
// TODO(jiabin): Support this usage inputs_.at(name)[index]->Name()
auto it = inputs_.find(name);
PADDLE_ENFORCE_NE(it, inputs_.end(),
paddle::platform::errors::PreconditionNotMet(
"Can not find [%s] in Input", name));
return inputs_.at(name)[index]->name();
}

bool InputTypeAnyOf(
const std::string& name,
paddle::framework::proto::VarType::Type type) const override {
auto& inputs = inputs_.at(name);
return std::any_of(
inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<egr::EagerTensor>& var) {
return paddle::framework::ToVarType(var->Var().Type()) == type;
});
}

bool InputTypeAllOf(
const std::string& name,
paddle::framework::proto::VarType::Type type) const override {
auto& inputs = inputs_.at(name);
return std::all_of(
inputs.begin(), inputs.end(),
[&type](const std::shared_ptr<egr::EagerTensor>& var) {
return paddle::framework::ToVarType(var->Var().Type()) == type;
});
}

void SyncTypeAndDataType(const std::string& input_name,
const std::string& output_name,
int index = 0) override {
auto in_tensor = inputs_.at(input_name)[index];
auto out_tensor = outputs_.at(output_name)[index];
if (in_tensor != out_tensor) {
this->SetTensorType(
out_tensor, paddle::framework::ToVarType(in_tensor->Var().Type()));
}
}

void SetOutputType(const std::string& name,
paddle::framework::proto::VarType::Type type,
int index = 0) override {
if (index == paddle::framework::ALL_ELEMENTS) {
for (auto& item : outputs_.at(name)) {
this->SetTensorType(item, type);
}
} else {
auto& var = outputs_.at(name)[index];
this->SetTensorType(var, type);
}
}

void SetTensorType(std::shared_ptr<egr::EagerTensor> out,
paddle::framework::proto::VarType::Type type) {
switch (type) {
case paddle::framework::proto::VarType::LOD_TENSOR: {
out->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
break;
}
default: {
PADDLE_THROW(paddle::platform::errors::NotFound(
"Cannot found var type: %s while running runtime InferVarType",
paddle::framework::ToTypeName(type)));
}
}
}

paddle::framework::proto::VarType::Type GetInputType(
const std::string& name, const int& index = 0) const override {
return paddle::framework::ToVarType(inputs_.at(name)[index]->Var().Type());
}

paddle::framework::proto::VarType::Type GetOutputType(
const std::string& name, const int& index = 0) const override {
return paddle::framework::ToVarType(outputs_.at(name)[index]->Var().Type());
}

paddle::framework::proto::VarType::Type GetInputDataType(
const std::string& name, const int& index = 0) const override {
return inputs_.at(name)[index]
->Var()
.Get<paddle::framework::LoDTensor>()
.type();
}

void SetOutputDataType(const std::string& name,
paddle::framework::proto::VarType::Type type,
int index = 0) override {
// TODO(jiabin): It seems doesn't make sense to set data_type in EagerMode.
}

bool IsDygraph() const override { return true; }

protected:
bool HasVar(const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"HasVar is not supported in runtime InferVarType"));
}

const std::vector<std::string>& InputVars(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"InputVars is not supported in runtime InferVarType"));
}

const std::vector<std::string>& OutputVars(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"OutputVars is not supported in runtime InferVarType"));
}

paddle::framework::proto::VarType::Type GetVarType(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}

void SetVarType(const std::string& name,
paddle::framework::proto::VarType::Type type) override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}

paddle::framework::proto::VarType::Type GetVarDataType(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}

void SetVarDataType(const std::string& name,
paddle::framework::proto::VarType::Type type) override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not manipulate var in runtime InferVarType"));
}

std::vector<paddle::framework::proto::VarType::Type> GetVarDataTypes(
const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"GetVarDataTypes is not supported in runtime InferVarType"));
}

void SetVarDataTypes(
const std::string& name,
const std::vector<paddle::framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"SetVarDataTypes is not supported in runtime InferVarType"));
}

std::vector<int64_t> GetVarShape(const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}

void SetVarShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}

int32_t GetVarLoDLevel(const std::string& name) const override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}

void SetVarLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}

private:
const NameTensorMap& inputs_;
const NameTensorMap& outputs_;
const paddle::framework::AttributeMap& attrs_;
const paddle::framework::AttributeMap& default_attrs_;
};

} // namespace egr
Loading

1 comment on commit 471fa1e

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 471fa1e Nov 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #1 Commit ID: 471fa1e contains failed CI.

🔹 Failed: PR-CI-Py3

Unknown Failed
2021-11-26 10:51:56 + CI_SKIP_CPP_TEST=
2021-11-26 10:51:56 + '[' ON == ON ']'
2021-11-26 10:51:56 + python3.7 /workspace/Paddle/tools/get_pr_ut.py
2021-11-26 10:52:01 Traceback (most recent call last):
2021-11-26 10:52:01 File "/workspace/Paddle/tools/get_pr_ut.py", line 402, in
2021-11-26 10:52:01 pr_checker.init()
2021-11-26 10:52:01 File "/workspace/Paddle/tools/get_pr_ut.py", line 69, in init
2021-11-26 10:52:01 if last_commit.message.find('test=allcase') != -1:
2021-11-26 10:52:01 AttributeError: 'NoneType' object has no attribute 'message'
2021-11-26 10:52:01 + EXCODE=1
2021-11-26 10:52:01 + echo 'EXCODE: 1'
2021-11-26 10:52:01 EXCODE: 1
2021-11-26 10:52:01 + echo 'ipipe_log_param_EXCODE: 1'
2021-11-26 10:52:01 ipipe_log_param_EXCODE: 1
2021-11-26 10:52:01 + [[ 1 -eq 0 ]]
2021-11-26 10:52:01 + set +x
2021-11-26 10:52:01 + exit 1
2021-11-26 10:52:01 {build code state=1}
2021-11-26 10:52:11 kill agent BUILD_CODE_FAIL

🔹 Failed: PR-CI-Coverage

Unknown Failed
2021-11-26 10:56:18 + export CI_SKIP_CPP_TEST=
2021-11-26 10:56:18 + CI_SKIP_CPP_TEST=
2021-11-26 10:56:18 + '[' ON == ON ']'
2021-11-26 10:56:18 + python3.7 /paddle/tools/get_pr_ut.py
2021-11-26 10:56:21 Traceback (most recent call last):
2021-11-26 10:56:21 File "/paddle/tools/get_pr_ut.py", line 402, in
2021-11-26 10:56:21 pr_checker.init()
2021-11-26 10:56:21 File "/paddle/tools/get_pr_ut.py", line 69, in init
2021-11-26 10:56:21 if last_commit.message.find('test=allcase') != -1:
2021-11-26 10:56:21 AttributeError: 'NoneType' object has no attribute 'message'
2021-11-26 10:56:21 + EXCODE=1
2021-11-26 10:56:21 + echo 1
2021-11-26 10:56:21 1
2021-11-26 10:56:21 + echo 'ipipe_log_param_EXCODE: 1'
2021-11-26 10:56:21 ipipe_log_param_EXCODE: 1
2021-11-26 10:56:21 + '[' 1 -ne 0 ']'
2021-11-26 10:56:21 + '[' 1 -ne 9 ']'
2021-11-26 10:56:21 + exit 1
2021-11-26 10:56:21 {build code state=1}

Please sign in to comment.