Skip to content

Commit

Permalink
infershaped autogen (PR #1), test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Feb 8, 2022
1 parent eacfc1e commit 070fda4
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 32 deletions.
3 changes: 2 additions & 1 deletion paddle/infrt/naive/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
infershaped/infershaped_kernel_launchers.cc
)

cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
30 changes: 14 additions & 16 deletions paddle/infrt/naive/infershaped/elementwise_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"

// This file contains a example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by script.
Expand All @@ -32,39 +33,36 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
*c->mutable_shape() = a.shape();
}

static void ElementwiseAdd(const tensor::DenseHostTensor& a,
static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c) {}

// TODO(zhiqiang) This class should be generated by a script offline.
class ElementwiseAddLauncher : public InferShapedKernelLauncher {
template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
InferShapedFunc infershape>
class KernelLauncher : public InferShapedKernelLauncher {
public:
static const uint16_t input_tensor_indices[2];
static const uint16_t num_input_tensors{2};
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true};

void Invoke(host_context::KernelFrame* frame) override {
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if (infershape_kernel_frame_builder.IsEmpty()) {
CreateKernelFrameForInferShape(frame);
}
if (turn_on_infer_shape_cache) {
if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
INFRT_KERNEL(ElementwiseAddInferShape)
(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
&infershape_kernel_frame_builder);
BuildInferShapeCache(num_input_tensors);
}
} else {
INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
}

INFRT_KERNEL(ElementwiseAdd)(frame);
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
}
};

const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};

} // namespace naive
} // namespace infrt
14 changes: 14 additions & 0 deletions paddle/infrt/naive/infershaped/infershape_launchers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {

namespace {
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c);
}

TEST(utils, registry) {
constexpr uint8_t count =
InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
CHECK_EQ(count, 2U);
}

TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
Expand All @@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
tensor::DenseHostTensor c({2, 8}, GetDType<float>());

host_context::KernelFrameBuilder kernel_frame_builder;
kernel_frame_builder.AddArgument(new host_context::Value(0));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
Expand Down
17 changes: 7 additions & 10 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace naive {
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
host_context::KernelFrame* frame) {
for (host_context::Value* value :
frame->GetValues(0, frame->GetNumElements())) {
frame->GetValues(1, frame->GetNumElements() - 1)) {
// TODO(Superjomn) To extend this.
if (value->is_type<tensor::DenseHostTensor>()) {
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
Expand All @@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
}

void InferShapedKernelLauncher::BuildInferShapeCache(
const uint16_t* input_indices, const uint16_t num_inputs) {
const uint16_t num_inputs) {
tensor_shape_cache.resize(num_inputs);
for (uint16_t i = 0; i < num_inputs; i++) {
tensor_shape_cache[i] =
infershape_kernel_frame_builder.GetArgAt(input_indices[i])
->get<MetaTensor>()
.shape();
infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
}
}

bool InferShapedKernelLauncher::IsShapeChanged(
const uint16_t* input_indices, const uint16_t num_inputs) const {
const uint16_t num_inputs) const {
if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
return true;

bool changed = false;
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
changed = changed || (tensor_shape_cache[i] !=
infershape_kernel_frame_builder
.GetArgAt<MetaTensor>(input_indices[i])
.shape());
changed = changed ||
(tensor_shape_cache[i] !=
infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
}
return changed;
}
Expand Down
6 changes: 2 additions & 4 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {

//! Build or update the infer-shape cache using the latest shape from
//! InferShapeFrame.
void BuildInferShapeCache(const uint16_t* input_indices,
const uint16_t num_inputs);
void BuildInferShapeCache(const uint16_t num_inputs);

//! Compare the latest shape with the shape cache.
bool IsShapeChanged(const uint16_t* input_indices,
const uint16_t num_inputs) const;
bool IsShapeChanged(const uint16_t num_inputs) const;

// values to hold the TensorMeta.
llvm::SmallVector<host_context::ValueRef, 3> values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
// limitations under the License.

#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"

#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"

namespace infrt {
namespace naive {

using ElementwiseAddLauncher =
KernelLauncher<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>;

void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
registry->AddKernel("elementwise_add",
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
Expand Down
77 changes: 77 additions & 0 deletions paddle/infrt/naive/infershaped/infershaped_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {
namespace infershaped {

using KeyType = const tensor::DenseHostTensor&;
using CountType = uint8_t;

constexpr CountType value(std::true_type) { return 1; }

constexpr CountType value(std::false_type) { return 0; }

template <typename T>
constexpr CountType value() {
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
}

template <typename FirstArg>
constexpr CountType count(CountType num) {
return num;
}

template <typename FirstArg>
constexpr CountType count() {
return 0;
}

template <>
constexpr CountType count<KeyType>(CountType num) {
return num + 1;
}

template <>
constexpr CountType count<KeyType>() {
return 1;
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}

} // namespace infershaped

template <typename F>
struct InferShapeHelper;

template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};

} // namespace naive
} // namespace infrt

1 comment on commit 070fda4

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.