
Commit

move get expected kernel args into pten (#38825)
chenwhql authored Jan 10, 2022
1 parent 657b674 commit 3a23c1a
Showing 14 changed files with 235 additions and 68 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.cc
@@ -1287,7 +1287,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const {
pt_kernel_signature_.reset(
new KernelSignature(std::move(this->GetExpectedPtenKernelArgs(ctx))));
VLOG(6) << KernelSignatureToString(*pt_kernel_signature_.get());
VLOG(6) << *pt_kernel_signature_.get();

kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
40 changes: 40 additions & 0 deletions paddle/fluid/framework/operator.h
@@ -40,6 +40,7 @@ limitations under the License. */
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/flat_hash_map.h"

#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/include/core.h"

namespace paddle {
@@ -438,6 +439,45 @@ class ExecutionContext {
const RuntimeContext& ctx_;
};

// TODO(chenweihang): split impl based on OpProto or Dygraph if needed
class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
public:
explicit ExecutionArgumentMappingContext(const ExecutionContext& ctx)
: ctx_(ctx) {}

bool HasInput(const std::string& name) const override {
return ctx_.HasInput(name);
}

bool HasOutput(const std::string& name) const override {
return ctx_.HasOutput(name);
}

bool HasAttr(const std::string& name) const override {
return ctx_.HasAttr(name);
}

size_t InputSize(const std::string& name) const override {
return ctx_.InputSize(name);
}

size_t OutputSize(const std::string& name) const override {
return ctx_.OutputSize(name);
}

bool IsDenseTensorInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::Tensor>() ||
ctx_.InputVar(name)->IsType<framework::LoDTensor>();
}

bool IsSelectedRowsInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::SelectedRows>();
}

private:
const ExecutionContext& ctx_;
};

template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

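Note: ExecutionArgumentMappingContext only depends on the abstract pten::ArgumentMappingContext interface added by this commit, so argument mapping functions can also be exercised without a full fluid ExecutionContext. The following is a minimal, hypothetical sketch (TestArgumentMappingContext is not part of this commit), assuming only the interface in paddle/pten/core/arg_map_context.h:

```cpp
// Hypothetical test double for pten::ArgumentMappingContext (not part of this
// commit). It answers the interface queries from fixed in-memory sets, which
// is enough to unit-test an argument mapping function in isolation.
#include <set>
#include <string>
#include <utility>

#include "paddle/pten/core/arg_map_context.h"

class TestArgumentMappingContext : public pten::ArgumentMappingContext {
 public:
  TestArgumentMappingContext(std::set<std::string> inputs,
                             std::set<std::string> attrs)
      : inputs_(std::move(inputs)), attrs_(std::move(attrs)) {}

  bool HasInput(const std::string& name) const override {
    return inputs_.count(name) > 0;
  }
  bool HasOutput(const std::string& name) const override { return true; }
  bool HasAttr(const std::string& name) const override {
    return attrs_.count(name) > 0;
  }

  size_t InputSize(const std::string& name) const override {
    return inputs_.count(name);
  }
  size_t OutputSize(const std::string& name) const override { return 1; }

  // Assume every known input is a dense tensor in this fake context.
  bool IsDenseTensorInput(const std::string& name) const override {
    return inputs_.count(name) > 0;
  }
  bool IsSelectedRowsInput(const std::string& name) const override {
    return false;
  }

 private:
  std::set<std::string> inputs_;
  std::set<std::string> attrs_;
};
```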
10 changes: 0 additions & 10 deletions paddle/fluid/framework/pten_utils.cc
@@ -196,15 +196,5 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames());
}

std::string KernelSignatureToString(const KernelSignature& signature) {
std::stringstream os;
os << "Kernel Signature - name: " << signature.name
<< "; inputs: " << string::join_strings(std::get<0>(signature.args), ", ")
<< "; attributes: "
<< string::join_strings(std::get<1>(signature.args), ", ") << "; outputs: "
<< string::join_strings(std::get<2>(signature.args), ", ");
return os.str();
}

} // namespace framework
} // namespace paddle
26 changes: 4 additions & 22 deletions paddle/fluid/framework/pten_utils.h
@@ -22,17 +22,19 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"

namespace paddle {
namespace framework {

using KernelSignature = pten::KernelSignature;

/* Kernel Key translate */

OpKernelType TransPtenKernelKeyToOpKernelType(
@@ -42,24 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(

/* Kernel Args parse */

struct KernelSignature {
std::string name;
KernelArgsTuple args;

KernelSignature() = default;
KernelSignature(std::string&& kernel_name,
paddle::SmallVector<std::string>&& inputs,
paddle::SmallVector<std::string>&& attrs,
paddle::SmallVector<std::string>&& outputs)
: name(std::move(kernel_name)),
args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const std::string& kernel_name,
const paddle::SmallVector<std::string>& inputs,
const paddle::SmallVector<std::string>& attrs,
const paddle::SmallVector<std::string>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
};

// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
@@ -88,7 +72,5 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};

std::string KernelSignatureToString(const KernelSignature& signature);

} // namespace framework
} // namespace paddle
5 changes: 0 additions & 5 deletions paddle/fluid/framework/type_defs.h
@@ -84,10 +84,5 @@ using InferShapeFN = std::function<void(InferShapeContext*)>;
using InplacePair = std::unordered_map<std::string, std::string>;
using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;

// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;

} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/prepared_operator.cc
@@ -164,7 +164,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
VLOG(6) << framework::KernelSignatureToString(pt_kernel_signature);
VLOG(6) << pt_kernel_signature;

auto pt_kernel_name = pt_kernel_signature.name;
auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key);
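Note: the VLOG(6) line above now relies on the operator<< overload added in paddle/pten/core/arg_map_context.cc, which keeps the format of the removed KernelSignatureToString helper. As a rough sketch (using the scale signature from scale_op.cc below), the logged line would look like this:

```cpp
// Sketch only: stream a pten::KernelSignature the way VLOG(6) now does.
#include <iostream>

#include "paddle/pten/core/arg_map_context.h"

int main() {
  pten::KernelSignature sig(
      "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"});
  std::cout << sig << std::endl;
  // Expected output (same layout as the former KernelSignatureToString):
  // Kernel Signature - name: scale; inputs: X; attributes: scale, bias,
  // bias_after_scale; outputs: Out
  return 0;
}
```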
16 changes: 3 additions & 13 deletions paddle/fluid/operators/scale_op.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <string>
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/ops/compat/scale_args_fn.h"

namespace paddle {
namespace framework {
@@ -73,19 +74,8 @@ class ScaleOp : public framework::OperatorWithKernel {

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>() ||
ctx.InputVar("X")->IsType<framework::Tensor>()) {
std::string scale_attr;
if (ctx.HasInput("ScaleTensor")) {
scale_attr = "ScaleTensor";
} else {
scale_attr = "scale";
}
return framework::KernelSignature(
"scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
}
// TODO(chenweihang): support other cases after selected rows added
return framework::KernelSignature("scale.unregistered", {}, {}, {});
framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::ScaleOpArgumentMapping(arg_mapping_ctx);
}
};

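Note: paddle/pten/ops/compat/scale_args_fn.h itself is not shown in this diff. The sketch below is a plausible reconstruction of pten::ScaleOpArgumentMapping based on the inline logic removed above; the real header may differ.

```cpp
// Plausible sketch of paddle/pten/ops/compat/scale_args_fn.h (not shown in
// this diff): the branching removed from ScaleOp::GetExpectedPtenKernelArgs,
// rewritten against the backend-agnostic ArgumentMappingContext.
#pragma once

#include <string>

#include "paddle/pten/core/arg_map_context.h"

namespace pten {

inline KernelSignature ScaleOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    // Prefer the runtime ScaleTensor input over the static scale attribute.
    std::string scale_attr =
        ctx.HasInput("ScaleTensor") ? "ScaleTensor" : "scale";
    return KernelSignature(
        "scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"});
  }
  // Other cases (e.g. SelectedRows input) are not supported yet.
  return KernelSignature("scale.unregistered", {}, {}, {});
}

}  // namespace pten
```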
2 changes: 1 addition & 1 deletion paddle/pten/CMakeLists.txt
@@ -23,7 +23,7 @@ add_subdirectory(ops)
add_subdirectory(tests)

# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context infermeta)
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
3 changes: 2 additions & 1 deletion paddle/pten/core/CMakeLists.txt
@@ -8,8 +8,9 @@ endif()

cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)

cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)

60 changes: 60 additions & 0 deletions paddle/pten/core/arg_map_context.cc
@@ -0,0 +1,60 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/arg_map_context.h"

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"

namespace pten {

OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() {
static OpArgumentMappingFnMap g_op_arg_mapping_fn_map;
return g_op_arg_mapping_fn_map;
}

bool OpArgumentMappingFnMap::Has(const std::string& op_type) const {
return fn_map_.find(op_type) != fn_map_.end();
}

const ArgumentMappingFn& OpArgumentMappingFnMap::Get(
const std::string& op_type) const {
auto it = fn_map_.find(op_type);
PADDLE_ENFORCE_NE(
it,
fn_map_.end(),
paddle::platform::errors::NotFound(
"Operator `%s`'s argument mapping funciton is not registered.",
op_type));
return it->second;
}

void OpArgumentMappingFnMap::Emplace(const std::string& op_type,
const std::string api_name,
ArgumentMappingFn fn) {
name_map_.emplace(op_type, api_name);
fn_map_.emplace(op_type, fn);
}

std::ostream& operator<<(std::ostream& os, KernelSignature signature) {
os << "Kernel Signature - name: " << signature.name << "; inputs: "
<< paddle::string::join_strings(std::get<0>(signature.args), ", ")
<< "; attributes: "
<< paddle::string::join_strings(std::get<1>(signature.args), ", ")
<< "; outputs: "
<< paddle::string::join_strings(std::get<2>(signature.args), ", ");
return os;
}

} // namespace pten
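Note: OpArgumentMappingFnMap is a process-wide registry from a fluid op type to its pten api name and argument mapping function. This commit does not show how registration is wired up, so the sketch below uses hypothetical names ("my_op", DummyArgumentMapping) purely to illustrate the Emplace/Has/Get API:

```cpp
// Minimal usage sketch of pten::OpArgumentMappingFnMap; the op type "my_op",
// the api name "dummy", and DummyArgumentMapping are hypothetical.
#include "paddle/pten/core/arg_map_context.h"

namespace {

pten::KernelSignature DummyArgumentMapping(
    const pten::ArgumentMappingContext& /*ctx*/) {
  return pten::KernelSignature("dummy", {"X"}, {"alpha"}, {"Out"});
}

}  // namespace

void RegisterAndLookupExample() {
  auto& fn_map = pten::OpArgumentMappingFnMap::Instance();
  // Map the fluid op type "my_op" to the pten api name "dummy".
  fn_map.Emplace("my_op", "dummy", DummyArgumentMapping);

  if (fn_map.Has("my_op")) {
    const pten::ArgumentMappingFn& fn = fn_map.Get("my_op");
    // `fn` can be invoked with any ArgumentMappingContext implementation,
    // e.g. framework::ExecutionArgumentMappingContext.
    (void)fn;
  }
}
```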
86 changes: 86 additions & 0 deletions paddle/pten/core/arg_map_context.h
@@ -0,0 +1,86 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <ostream>
#include <string>
#include <tuple>

#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"

namespace pten {

// tuple(input_names, attr_names, output_names)
using KernelArgsTuple = std::tuple<paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>,
paddle::SmallVector<std::string>>;

// TODO(chenweihang): Add more methods if needed in future
class ArgumentMappingContext {
public:
virtual ~ArgumentMappingContext() = default;

virtual bool HasInput(const std::string& name) const = 0;
virtual bool HasOutput(const std::string& name) const = 0;
virtual bool HasAttr(const std::string& name) const = 0;

virtual size_t InputSize(const std::string& name) const = 0;
virtual size_t OutputSize(const std::string& name) const = 0;

virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};

struct KernelSignature {
std::string name;
KernelArgsTuple args;

KernelSignature() = default;
KernelSignature(std::string&& kernel_name,
paddle::SmallVector<std::string>&& inputs,
paddle::SmallVector<std::string>&& attrs,
paddle::SmallVector<std::string>&& outputs)
: name(std::move(kernel_name)),
args(std::make_tuple(inputs, attrs, outputs)) {}
KernelSignature(const std::string& kernel_name,
const paddle::SmallVector<std::string>& inputs,
const paddle::SmallVector<std::string>& attrs,
const paddle::SmallVector<std::string>& outputs)
: name(kernel_name), args(std::make_tuple(inputs, attrs, outputs)) {}
};

std::ostream& operator<<(std::ostream& os, KernelSignature signature);

using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&);

class OpArgumentMappingFnMap {
public:
static OpArgumentMappingFnMap& Instance();

bool Has(const std::string& op_type) const;

const ArgumentMappingFn& Get(const std::string& op_type) const;

void Emplace(const std::string& op_type,
const std::string api_name,
ArgumentMappingFn fn);

private:
paddle::flat_hash_map<std::string, std::string> name_map_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> fn_map_;
};

} // namespace pten
13 changes: 0 additions & 13 deletions paddle/pten/core/kernel_def.h
@@ -26,17 +26,4 @@ using KernelArgsDefFn = void (*)(Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key,
KernelArgsDef* args_def);

// Multiple kernels of the same operation are distinguished by the difference
// of the overload name. For the convenience of reuse, we define some overload
// naming strings for the naming of the kernel

// For kernels that contains dynamic tensor attribute and it need to be always
// on host device, such as `ScaleTensor`
constexpr char kContainHostTensorSuffix[] = "host";

// For kernels with SelectedRowsTensor input and output
constexpr char kContainSelectedRowsSuffix[] = "sr";

// For kernels with intermediate output
constexpr char kContainMidOutputTensorSuffix[] = "mid";
} // namespace pten
2 changes: 1 addition & 1 deletion paddle/pten/kernels/CMakeLists.txt
@@ -24,7 +24,7 @@ endif()
# pten depends all pten kernel targets
set_property(GLOBAL PROPERTY PTEN_KERNELS "")

set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory convert_utils)
set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
(Note: the diff for one of the 14 changed files was not loaded and is not shown above.)
