Commit 5259b14: add dist meta tensor

chenwhql committed Aug 29, 2023
1 parent 0603212 commit 5259b14
Showing 10 changed files with 172 additions and 101 deletions.
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -9,6 +9,7 @@ collect_srcs(
  dist_mapper.cc
  reshard_utils.cc
  dist_tensor.cc
+ dist_meta_tensor.cc
  inferspmd_utils.cc
  reshard_function.cc
  reshard_split_functor.cc
51 changes: 51 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc
@@ -0,0 +1,51 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

namespace phi {
namespace distributed {

phi::DDim DistMetaTensor::dims() const {
  // Member values in tensor_ take precedence over those in DistMetaTensor.
  if (tensor_ != nullptr) {
    PADDLE_ENFORCE_EQ(this->is_dist(),
                      true,
                      phi::errors::InvalidArgument(
                          "The current MetaTensor doesn't contain "
                          "a DistTensor when calling the `dims` method."));
    return MetaTensor::dims();
  } else {
    return dims_;
  }
}

const distributed::TensorDistAttr& DistMetaTensor::dist_attr() const {
  // Member values in tensor_ take precedence over those in DistMetaTensor.
  if (tensor_ != nullptr) {
    PADDLE_ENFORCE_EQ(this->is_dist(),
                      true,
                      phi::errors::InvalidArgument(
                          "The current MetaTensor doesn't contain "
                          "a DistTensor when calling the `dist_attr` method."));
    return static_cast<phi::distributed::DistTensor*>(tensor_)->dist_attr();
  } else {
    return dist_attr_;
  }
}

} // namespace distributed
} // namespace phi
68 changes: 68 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h
@@ -0,0 +1,68 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/meta_tensor.h"

namespace phi {
namespace distributed {

class DistMetaTensor : public MetaTensor {
 public:
  // Supporting implicit construction makes the class easier to use.
  DistMetaTensor(TensorBase* tensor)  // NOLINT
      : MetaTensor(tensor) {}
  DistMetaTensor(const TensorBase& tensor)  // NOLINT
      : MetaTensor(tensor) {}
  DistMetaTensor(const TensorBase* tensor)  // NOLINT
      : MetaTensor(tensor) {}
  DistMetaTensor(TensorBase& tensor)  // NOLINT
      : MetaTensor(tensor) {}
  // For static mode only
  DistMetaTensor(const phi::DDim& dims, const TensorDistAttr& dist_attr)
      : dims_(dims), dist_attr_(dist_attr) {}

  DistMetaTensor(DistMetaTensor&&) = default;
  DistMetaTensor& operator=(DistMetaTensor&&) = default;
  DistMetaTensor(const DistMetaTensor&) = default;
  DistMetaTensor& operator=(const DistMetaTensor&) = default;

  virtual ~DistMetaTensor() = default;

  DDim dims() const override;

  const distributed::TensorDistAttr& dist_attr() const;

 private:
  /**
   * Note: When the static graph applies the semi-automatic parallel
   * sharding-inference (SPMD) rules, DistMetaTensor subclasses MetaTensor
   * and adds the members dims_ and dist_attr_ so that the rules' input
   * arguments are easy to package up.
   *
   * The information in these two members also lives in the tensor_ held by
   * the base-class MetaTensor, so it is redundant, and callers must take
   * care to keep the two copies consistent. Both members are read-only:
   * their values cannot be changed after construction. To change them, set
   * them directly on tensor_.
   */
  phi::DDim dims_;
  TensorDistAttr dist_attr_;
};

} // namespace distributed
} // namespace phi
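
The class above gives two ways to build a DistMetaTensor. Below is a minimal usage sketch of both paths; the helper function, shapes, and attribute values are illustrative assumptions, not part of this commit:

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"

using phi::distributed::DistMetaTensor;
using phi::distributed::DistTensor;
using phi::distributed::TensorDistAttr;

// Hypothetical helper showing when each constructor applies.
void ConstructionSketch(DistTensor* runtime_tensor) {
  // Dynamic mode: wrap an existing DistTensor; dims() and dist_attr()
  // read through tensor_, which takes precedence over the local members.
  DistMetaTensor dynamic_meta(runtime_tensor);

  // Static mode: no runtime tensor exists yet, so the meta tensor carries
  // dims_ and dist_attr_ itself; both are fixed after construction.
  TensorDistAttr attr;  // default-constructed attributes
  DistMetaTensor static_meta(phi::make_ddim({8, 16}), attr);
}

(phi::make_ddim is assumed from phi's ddim utilities.)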
4 changes: 2 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -17,15 +17,15 @@ limitations under the License. */
namespace phi {
namespace distributed {

-void InferSpmdContext::EmplaceBackInput(MetaTensor input) {
+void InferSpmdContext::EmplaceBackInput(DistMetaTensor input) {
  inputs_.emplace_back(std::move(input));
}

void InferSpmdContext::EmplaceBackAttr(Attribute attr) {
  attrs_.emplace_back(std::move(attr));
}

-const MetaTensor& InferSpmdContext::InputAt(size_t idx) const {
+const DistMetaTensor& InferSpmdContext::InputAt(size_t idx) const {
  return inputs_.at(idx);
}

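With these signatures in place, a caller assembles the context from DistMetaTensor inputs. A hedged sketch (the shape and the bool attribute are made-up examples; the implicit bool-to-Attribute conversion is assumed from phi::Attribute):

#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"

using phi::distributed::DistMetaTensor;
using phi::distributed::InferSpmdContext;
using phi::distributed::TensorDistAttr;

void BuildContextSketch() {
  InferSpmdContext ctx;
  // Inputs are now DistMetaTensor rather than MetaTensor.
  ctx.EmplaceBackInput(
      DistMetaTensor(phi::make_ddim({4, 8}), TensorDistAttr()));
  ctx.EmplaceBackAttr(false);  // e.g. a trans_x flag for a matmul-style rule
  const DistMetaTensor& x = ctx.InputAt(0);  // typed accessor from this diff
  (void)x;
}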
14 changes: 7 additions & 7 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
@@ -23,10 +23,10 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/utils/any.h"
#include "paddle/utils/flat_hash_map.h"
@@ -39,14 +39,14 @@ class InferSpmdContext {
 public:
  InferSpmdContext() = default;
  InferSpmdContext(
-      paddle::small_vector<MetaTensor, phi::kInputSmallVectorSize> inputs,
+      paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs,
      paddle::small_vector<Attribute, phi::kAttrSmallVectorSize> attrs)
      : inputs_(std::move(inputs)), attrs_(std::move(attrs)) {}

-  void EmplaceBackInput(MetaTensor input);
+  void EmplaceBackInput(DistMetaTensor input);
  void EmplaceBackAttr(Attribute attr);

-  const MetaTensor& InputAt(size_t idx) const;
+  const DistMetaTensor& InputAt(size_t idx) const;

  template <typename AttrType>
  AttrType AttrAt(size_t idx) const;
@@ -55,7 +55,7 @@

 private:
  // For now we only need `inputs`; for backward, the `output` is passed as
  // an input.
-  paddle::small_vector<MetaTensor, phi::kInputSmallVectorSize> inputs_;
+  paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs_;
  // Because dygraph attribute arguments carry no `attr name`,
  // we use a vector instead of a map.
  paddle::small_vector<Attribute, phi::kAttrSmallVectorSize> attrs_;
@@ -86,12 +86,12 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {

  // TODO(chenweihang): support other input type later as needed
  template <typename... Tail>
-  struct InferSpmdFnCallHelper<const MetaTensor&, Tail...> {
+  struct InferSpmdFnCallHelper<const DistMetaTensor&, Tail...> {
    template <int in_idx, int attr_idx, typename... PreviousArgs>
    static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) {
      static_assert(attr_idx == 0,
                    "InferSpmd's Input should appear before Attributes.");
-      const MetaTensor& arg = ctx.InputAt(in_idx);
+      const DistMetaTensor& arg = ctx.InputAt(in_idx);
      return InferSpmdFnCallHelper<Tail...>::template Call<in_idx + 1,
                                                           attr_idx>(
          ctx, pargs..., arg);
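The specialization above peels one `const DistMetaTensor&` off the argument list per recursion step, reading it from the context by position. For readers unfamiliar with the pattern, here is a self-contained, compilable analogue with simplified stand-in types (Tensor, Context, and FnImpl are not the real phi classes):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Tensor {
  std::string name;
};

struct Context {
  std::vector<Tensor> inputs;
  const Tensor& InputAt(std::size_t i) const { return inputs.at(i); }
};

template <typename Fn, Fn fn>
struct FnImpl;

template <typename Return, typename... Args, Return (*fn)(Args...)>
struct FnImpl<Return (*)(Args...), fn> {
  static Return Call(const Context& ctx) {
    return CallHelper<Args..., TypeTag<int>>::template Call<0>(ctx);
  }

 private:
  template <typename T>
  struct TypeTag {};

  template <typename... RemainingArgs>
  struct CallHelper;

  // Peel one `const Tensor&` off the argument list and recurse with the
  // input index advanced, mirroring InferSpmdFnCallHelper.
  template <typename... Tail>
  struct CallHelper<const Tensor&, Tail...> {
    template <int in_idx, typename... PreviousArgs>
    static Return Call(const Context& ctx, PreviousArgs&... pargs) {
      const Tensor& arg = ctx.InputAt(in_idx);
      return CallHelper<Tail...>::template Call<in_idx + 1>(ctx, pargs..., arg);
    }
  };

  // End tag reached: every input has been consumed, invoke the target.
  template <typename T>
  struct CallHelper<TypeTag<T>> {
    template <int in_idx, typename... PreviousArgs>
    static Return Call(const Context& ctx, PreviousArgs&... pargs) {
      return fn(pargs...);
    }
  };
};

std::string Concat(const Tensor& a, const Tensor& b) {
  return a.name + "+" + b.name;
}

int main() {
  Context ctx{{{"x"}, {"y"}}};
  std::cout << FnImpl<decltype(&Concat), &Concat>::Call(ctx) << "\n";  // x+y
}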
13 changes: 0 additions & 13 deletions paddle/phi/core/meta_tensor.cc
@@ -17,8 +17,6 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/string_tensor.h"
@@ -278,15 +276,4 @@ const LoD& MetaTensor::lod() const {
  }
}

-/////////////// For Auto Parallel ////////////////
-
-const distributed::TensorDistAttr& MetaTensor::dist_attr() const {
-  PADDLE_ENFORCE_EQ(
-      this->is_dist(),
-      true,
-      phi::errors::InvalidArgument("The current MetaTensor doesn't contains "
-                                   "DistTensor when call `dist_attr` method."));
-  return static_cast<phi::distributed::DistTensor*>(tensor_)->dist_attr();
-}
-
}  // namespace phi
8 changes: 1 addition & 7 deletions paddle/phi/core/meta_tensor.h
@@ -22,9 +22,6 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"

namespace phi {
namespace distributed {
class TensorDistAttr;
} // namespace distributed

struct MetaConfig {
bool is_runtime{true};
@@ -92,9 +89,6 @@ class MetaTensor {

  virtual bool is_same_tensor(const MetaTensor& meta_tensor) const;

-  // For auto parallel
-  const distributed::TensorDistAttr& dist_attr() const;
-
  virtual operator unspecified_bool_type() const {
    return tensor_ == nullptr ? 0 : unspecified_bool_true;
  }
@@ -104,7 +98,7 @@
 protected:
  static void unspecified_bool_true() {}

- private:
+ protected:
  // Because the lod differs between compile time and runtime,
  // `LoD` cannot be exposed in public methods.
  const LoD& lod() const;
10 changes: 5 additions & 5 deletions paddle/phi/infermeta/spmd_rules/matmul.cc
@@ -114,8 +114,8 @@ void FillMatmulOperandNotation(const int x_ndim,

////////////////// InferMeta(Contains SPMD) Functions //////////////////

-SpmdInfo MatmulSpmdInferForward(const MetaTensor& x,
-                                const MetaTensor& y,
+SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x,
+                                const DistMetaTensor& y,
                                bool trans_x,
                                bool trans_y) {
  // Step0: verify input args based on matmul logic
@@ -221,9 +221,9 @@ SpmdInfo MatmulSpmdInferForward(const MetaTensor& x,
  return {{x_dist_attr_dst, y_dist_attr_dst}, {output_dist_attr_dst}};
}

-SpmdInfo MatmulSpmdInferBackward(const MetaTensor& x,
-                                 const MetaTensor& y,
-                                 const MetaTensor& out,
+SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x,
+                                 const DistMetaTensor& y,
+                                 const DistMetaTensor& out,
                                 bool trans_x,
                                 bool trans_y) {
  auto out_shape = phi::vectorize(out.dims());
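For completeness, a hedged sketch of calling the updated forward rule in static mode; the shapes and dims mappings are invented, and TensorDistAttr::set_dims_mapping is assumed from phi's dist_attr API rather than shown in this diff:

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/infermeta/spmd_rules/matmul.h"

using phi::distributed::DistMetaTensor;
using phi::distributed::TensorDistAttr;

phi::distributed::SpmdInfo MatmulRuleSketch() {
  // x: [64, 32] with rows sharded along mesh axis 0; y: [32, 48] replicated.
  TensorDistAttr x_attr;
  x_attr.set_dims_mapping({0, -1});
  TensorDistAttr y_attr;
  y_attr.set_dims_mapping({-1, -1});

  DistMetaTensor x(phi::make_ddim({64, 32}), x_attr);
  DistMetaTensor y(phi::make_ddim({32, 48}), y_attr);

  // Returns the chosen dist attrs for {x, y} and the inferred one for {out}.
  return phi::distributed::MatmulSpmdInferForward(
      x, y, /*trans_x=*/false, /*trans_y=*/false);
}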
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/spmd_rules/matmul.h
@@ -16,20 +16,20 @@ limitations under the License. */

#include <vector>

+#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"
#include "paddle/phi/infermeta/binary.h"

namespace phi {
namespace distributed {

-SpmdInfo MatmulSpmdInferForward(const MetaTensor& x,
-                                const MetaTensor& y,
+SpmdInfo MatmulSpmdInferForward(const DistMetaTensor& x,
+                                const DistMetaTensor& y,
                                bool trans_x,
                                bool trans_y);

-SpmdInfo MatmulSpmdInferBackward(const MetaTensor& x,
-                                 const MetaTensor& y,
-                                 const MetaTensor& out,
+SpmdInfo MatmulSpmdInferBackward(const DistMetaTensor& x,
+                                 const DistMetaTensor& y,
+                                 const DistMetaTensor& out,
                                 bool trans_x,
                                 bool trans_y);
