[PTen] Change all InferMeta functions #39222

Merged · 20 commits · Jan 30, 2022
6 changes: 4 additions & 2 deletions paddle/fluid/framework/custom_kernel_test.cc
@@ -212,11 +212,13 @@ TEST(CustomKernel, custom_kernel_dot) {
   kernel_context.EmplaceBackAttr(fake_attr_int64_vec);
   kernel_context.EmplaceBackAttr(fake_attr_int_vec);
 
-  auto out_meta = pten::DotInferMeta(dense_x->meta(), dense_y->meta());
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::DotInferMeta(*dense_x, *dense_y, &meta_out);
   kernel_context.EmplaceBackOutput(dense_out.get());  // idx:0 index:[0,1)
 
   // fake_input_vec: idx:1, index:[1,3)
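The hunk above is the template for every call site touched by this PR: instead of computing a `DenseTensorMeta` up front and moving it into the output, the output is created with an empty meta and InferMeta writes into it through a `MetaTensor` view. A minimal sketch of the pattern, assuming the pten headers from this branch; the `InferDotOutput` helper and the `binary.h` include path are illustrative assumptions, not part of the diff:

```cpp
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/binary.h"  // assumed location of DotInferMeta

// Illustrative helper: resolve the output meta of dot(x, y) in place.
void InferDotOutput(const pten::DenseTensor& x,
                    const pten::DenseTensor& y,
                    pten::DenseTensor* out) {
  // Bind a mutable MetaTensor view to the output tensor.
  pten::MetaTensor meta_out(out);
  // x and y convert to MetaTensor implicitly (see meta_tensor.h below);
  // dims/dtype/layout are written into `out` rather than returned.
  pten::DotInferMeta(x, y, &meta_out);
}
```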
8 changes: 8 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
@@ -186,6 +186,14 @@ class CompatMetaTensor : public pten::MetaTensor {
     }
   }
 
+  void share_meta(const MetaTensor& meta_tensor) override {
+    set_dims(meta_tensor.dims());
+    set_dtype(meta_tensor.dtype());
+    // VarDesc doesn't contain layout, so we cannot share layout
+    // set_layout(meta_tensor.layout());
+    share_lod(meta_tensor);
+  }
+
  private:
   const LoD& GetRuntimeLoD() const {
     auto* var = BOOST_GET_CONST(Variable*, var_);
35 changes: 15 additions & 20 deletions paddle/pten/api/lib/api_utils.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/pten/api/lib/utils/storage.h"
 #include "paddle/pten/core/compat/convert_utils.h"
 #include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/meta_tensor.h"
 
 namespace paddle {
 namespace experimental {
@@ -44,44 +45,38 @@ inline std::unique_ptr<std::vector<pten::DenseTensor>> TensorToDenseTensor(
 
 /* ----------------- for infer_meta --------------------- */
 
-inline const pten::DenseTensorMeta& GetDenseTensorMeta(
-    const pten::DenseTensor& tensor) {
-  return tensor.meta();
+inline pten::MetaTensor MakeMetaTensor(const pten::DenseTensor& tensor) {
+  return pten::MetaTensor(tensor);
 }
 
-inline std::vector<pten::DenseTensorMeta> GetDenseTensorMeta(
+inline std::vector<pten::MetaTensor> MakeMetaTensor(
     const std::vector<pten::DenseTensor>& tensors) {
-  std::vector<pten::DenseTensorMeta> metas;
-  metas.reserve(tensors.size());
+  std::vector<pten::MetaTensor> meta_tensors;
+  meta_tensors.reserve(tensors.size());
   for (const auto& t : tensors) {
-    metas.push_back(t.meta());
+    meta_tensors.emplace_back(t);
   }
-  return metas;
+  return meta_tensors;
 }
 
 /* ------------------ for output ----------------------- */
 
-inline pten::DenseTensor* SetKernelOutput(const pten::DenseTensorMeta& meta,
-                                          Backend backend,
-                                          Tensor* out) {
+inline pten::DenseTensor* SetKernelOutput(Backend backend, Tensor* out) {
   auto dense_tensor = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<SharedStorage>(pten::TransToFluidPlace(backend)),
-      meta);
+      pten::DenseTensorMeta());
   out->set_impl(dense_tensor);
   return dense_tensor.get();
 }
 
 inline std::vector<pten::DenseTensor*> SetKernelOutput(
-    const std::vector<pten::DenseTensorMeta>& metas,
-    Backend backend,
-    std::vector<Tensor>* out) {
-  size_t n = metas.size();
-  out->reserve(n);
-  std::vector<pten::DenseTensor*> results(n);
-  for (size_t i = 0; i < n; ++i) {
+    size_t out_size, Backend backend, std::vector<Tensor>* out) {
+  out->reserve(out_size);
+  std::vector<pten::DenseTensor*> results(out_size);
+  for (size_t i = 0; i < out_size; ++i) {
     auto tensor_ptr = std::make_shared<pten::DenseTensor>(
         pten::make_intrusive<SharedStorage>(pten::TransToFluidPlace(backend)),
-        metas[i]);
+        pten::DenseTensorMeta());
     results[i] = tensor_ptr.get();
     out->emplace_back();
     out->back().set_impl(tensor_ptr);
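Taken together, the reworked helpers invert the old order of operations in generated API code: outputs are allocated meta-less first, then InferMeta fills them. A sketch of how a generated function might use only the helpers above; the `dot` wrapper itself is hypothetical and not part of this diff:

```cpp
// Hypothetical generated API body (in namespace paddle::experimental),
// wired from the helpers defined in api_utils.h above.
Tensor dot(const pten::DenseTensor& x, const pten::DenseTensor& y,
           Backend backend) {
  Tensor out;
  // Create the output with an empty DenseTensorMeta; nothing is precomputed.
  pten::DenseTensor* dense_out = SetKernelOutput(backend, &out);

  // InferMeta fills the output's meta through a MetaTensor view.
  pten::MetaTensor meta_out(dense_out);
  pten::DotInferMeta(MakeMetaTensor(x), MakeMetaTensor(y), &meta_out);

  // ... kernel dispatch then writes into dense_out ...
  return out;
}
```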
11 changes: 5 additions & 6 deletions paddle/pten/api/lib/manual_api.cc
@@ -57,20 +57,19 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
   kernel_context.EmplaceBackInput(dense_x.get());
   kernel_context.EmplaceBackAttr(blocking);
 
-  // 4. InferMeta
-  auto out_meta = UnchangedInferMeta(dense_x->meta());
-
-  // 5. Prepare outputs
+  // 4. Prepare outputs & InferMeta
   auto dense_out = std::make_shared<pten::DenseTensor>(
       pten::make_intrusive<paddle::experimental::SharedStorage>(
           pten::TransToFluidPlace(backend)),
-      std::move(out_meta));
+      pten::DenseTensorMeta());
+  pten::MetaTensor meta_out(dense_out.get());
+  pten::UnchangedInferMeta(*dense_x, &meta_out);
   dense_out->mutable_data(pten::TransToFluidPlace(backend));
   kernel_context.EmplaceBackOutput(dense_out.get());
   Tensor out;
   out.set_impl(dense_out);
 
-  // 6. Call kernel
+  // 5. Call kernel
   kernel(&kernel_context);
 
   return out;
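One detail worth noting in the restructured `copy_to`: the meta must be resolved before `mutable_data` allocates the output, since the allocation size comes from the inferred dims. A condensed sketch of that ordering; the `PrepareOutput` helper name is illustrative, and the place type is assumed from `TransToFluidPlace`:

```cpp
// Illustrative helper condensing the output-preparation order used above.
std::shared_ptr<pten::DenseTensor> PrepareOutput(
    const pten::DenseTensor& x, const paddle::platform::Place& place) {
  auto out = std::make_shared<pten::DenseTensor>(
      pten::make_intrusive<paddle::experimental::SharedStorage>(place),
      pten::DenseTensorMeta());              // 1. start with an empty meta
  pten::MetaTensor meta_out(out.get());
  pten::UnchangedInferMeta(x, &meta_out);    // 2. fill dims/dtype/layout from x
  out->mutable_data(place);                  // 3. allocate after dims are known
  return out;
}
```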
10 changes: 0 additions & 10 deletions paddle/pten/core/infermeta_utils.h
@@ -26,16 +26,6 @@ limitations under the License. */
 
 namespace pten {
 
-// TODO(chenweihang): add other flags if needed
-struct MetaConfig {
-  bool is_runtime{true};
-
-  MetaConfig() = default;
-
-  // supporting implicit construction is easier to use
-  MetaConfig(bool is_runtime) : is_runtime(is_runtime) {}  // NOLINT
-};
-
 class InferMetaContext {
  public:
   InferMetaContext() = default;
28 changes: 20 additions & 8 deletions paddle/pten/core/meta_tensor.cc
@@ -33,7 +33,7 @@ void MetaTensor::set_dims(const DDim& dims) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->dims =
         dims;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -43,7 +43,7 @@ void MetaTensor::set_dtype(DataType dtype) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
         ->dtype = dtype;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting dtype for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -53,7 +53,7 @@ void MetaTensor::set_layout(DataLayout layout) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))
         ->layout = layout;
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
+    PADDLE_THROW(pten::errors::Unimplemented(
         "Unsupported setting layout for `%s`.", tensor_->type_info().name()));
   }
 }
@@ -63,18 +63,30 @@ void MetaTensor::share_lod(const MetaTensor& meta_tensor) {
     DenseTensorUtils::GetMutableMeta(static_cast<DenseTensor*>(tensor_))->lod =
         meta_tensor.lod();
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
-        "Unsupported share lod inplace for `%s`.",
-        tensor_->type_info().name()));
+    PADDLE_THROW(
+        pten::errors::Unimplemented("Unsupported sharing lod inplace for `%s`.",
+                                    tensor_->type_info().name()));
   }
 }
 
 const LoD& MetaTensor::lod() const {
   if (pten::DenseTensor::classof(tensor_)) {
     return static_cast<DenseTensor*>(tensor_)->lod();
   } else {
-    PADDLE_THROW(paddle::platform::errors::Unimplemented(
-        "Unsupported setting dims for `%s`.", tensor_->type_info().name()));
+    PADDLE_THROW(pten::errors::Unimplemented("Unsupported getting lod of `%s`.",
+                                             tensor_->type_info().name()));
   }
 }
 
+void MetaTensor::share_meta(const MetaTensor& meta_tensor) {
+  if (pten::DenseTensor::classof(tensor_)) {
+    set_dims(meta_tensor.dims());
+    set_dtype(meta_tensor.dtype());
+    set_layout(meta_tensor.layout());
+    share_lod(meta_tensor);
+  } else {
+    PADDLE_THROW(pten::errors::Unimplemented(
+        "Unsupported sharing meta for `%s`.", tensor_->type_info().name()));
+  }
+}
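`share_meta` gives InferMeta functions a one-call way to propagate a full meta (dims, dtype, layout, lod) from an input to an output, so an element-wise op's InferMeta reduces to a single line. A sketch; `ReluInferMeta` is hypothetical, but `MatmulGradInferMeta` in backward.cc below uses exactly this pattern:

```cpp
// Hypothetical element-wise InferMeta written in terms of share_meta.
void ReluInferMeta(const pten::MetaTensor& x, pten::MetaTensor* out) {
  // Copies dims, dtype and layout, and shares lod, in one call.
  out->share_meta(x);
}
```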
21 changes: 19 additions & 2 deletions paddle/pten/core/meta_tensor.h
@@ -23,11 +23,26 @@ limitations under the License. */
 
 namespace pten {
 
+// TODO(chenweihang): add other flags if needed
+struct MetaConfig {
+  bool is_runtime{true};
+
+  MetaConfig() = default;
+
+  // supporting implicit construction is easier to use
+  MetaConfig(bool is_runtime) : is_runtime(is_runtime) {}  // NOLINT
+};
+
 class MetaTensor {
  public:
-  explicit MetaTensor(TensorBase* tensor) : tensor_(tensor) {}
+  MetaTensor() = default;
+
+  // supporting implicit construction is easier to use
+  MetaTensor(TensorBase* tensor) : tensor_(tensor) {}  // NOLINT
+  MetaTensor(const TensorBase& tensor)  // NOLINT
+      : tensor_(const_cast<TensorBase*>(&tensor)) {}
+  MetaTensor(TensorBase& tensor) : tensor_(&tensor) {}  // NOLINT
 
   MetaTensor(const MetaTensor&) = default;
   MetaTensor(MetaTensor&&) = default;
   MetaTensor& operator=(const MetaTensor&) = delete;
@@ -42,7 +57,9 @@ class MetaTensor {
   virtual void set_dims(const DDim& dims);
   virtual void set_dtype(DataType dtype);
   virtual void set_layout(DataLayout layout);
+
   virtual void share_lod(const MetaTensor& meta_tensor);
+  virtual void share_meta(const MetaTensor& meta_tensor);
 
  private:
   // Because the lod in compiletime and runtime is different,
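The non-explicit constructors are what let call sites pass a `DenseTensor` (or any `TensorBase`) straight into a `const MetaTensor&` parameter, as custom_kernel_test.cc does above; `MetaConfig`, moved here from infermeta_utils.h, plays the same implicit-construction trick for a runtime/compile-time flag. A sketch of how the flag could be threaded through; `SomeInferMeta` and its default argument are hypothetical, not a signature from this diff:

```cpp
// Hypothetical InferMeta signature taking a MetaConfig (in namespace pten).
void SomeInferMeta(const pten::MetaTensor& x, pten::MetaTensor* out,
                   pten::MetaConfig config = pten::MetaConfig());

// At a compile-time call site, the implicit bool conversion lets callers
// write the flag inline, e.g.:
//   SomeInferMeta(x, &out, /*is_runtime=*/false);
// which matters because lod differs between compile time and runtime.
```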
4 changes: 2 additions & 2 deletions paddle/pten/infermeta/CMakeLists.txt
@@ -1,2 +1,2 @@
-cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils infermeta_utils)
-cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils)
+cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils meta_tensor infermeta_utils)
+cc_library(backward_infermeta SRCS backward.cc DEPS meta_tensor convert_utils)
16 changes: 9 additions & 7 deletions paddle/pten/infermeta/backward.cc
@@ -16,13 +16,15 @@ limitations under the License. */
 
 namespace pten {
 
-std::tuple<DenseTensorMeta, DenseTensorMeta> MatmulGradInferMeta(
-    const DenseTensorMeta& x_meta,
-    const DenseTensorMeta& y_meta,
-    const DenseTensorMeta& out_grad_meta,
-    bool transpose_x,
-    bool transpose_y) {
-  return std::make_tuple(x_meta, y_meta);
+void MatmulGradInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& out_grad_meta,
+                         bool transpose_x,
+                         bool transpose_y,
+                         MetaTensor* dx,
+                         MetaTensor* dy) {
+  dx->share_meta(x);
+  dy->share_meta(y);
 }
 
 }  // namespace pten
16 changes: 9 additions & 7 deletions paddle/pten/infermeta/backward.h
@@ -15,15 +15,17 @@ limitations under the License. */
 #pragma once
 
 #include <tuple>
-#include "paddle/pten/core/tensor_meta.h"
+
+#include "paddle/pten/core/meta_tensor.h"
 
 namespace pten {
 
-std::tuple<DenseTensorMeta, DenseTensorMeta> MatmulGradInferMeta(
-    const DenseTensorMeta& x_meta,
-    const DenseTensorMeta& y_meta,
-    const DenseTensorMeta& out_grad_meta,
-    bool transpose_x,
-    bool transpose_y);
+void MatmulGradInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& out_grad_meta,
+                         bool transpose_x,
+                         bool transpose_y,
+                         MetaTensor* dx,
+                         MetaTensor* dy);
 
 }  // namespace pten
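A sketch of calling the reworked gradient InferMeta: the gradients are passed as mutable `MetaTensor` outputs and filled via `share_meta`, rather than being returned as a tuple of metas. Tensor setup is elided and the `*_dense` names are placeholders:

```cpp
// x_dense, y_dense, out_grad_dense are existing inputs; dx_dense and
// dy_dense are pre-created gradient outputs with empty metas.
pten::MetaTensor meta_dx(&dx_dense);
pten::MetaTensor meta_dy(&dy_dense);
pten::MatmulGradInferMeta(x_dense, y_dense, out_grad_dense,
                          /*transpose_x=*/false, /*transpose_y=*/false,
                          &meta_dx, &meta_dy);
// Afterwards dx shares x's meta and dy shares y's meta (see backward.cc).
```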