[AutoParallel] Support operators that have mixed inputs. (PaddlePaddle#57774)
* [AutoParallel] Support operators that have mixed inputs, such as DenseTensor and DistTensor.

* Polish code according to review comments.
GhostScreaming authored and Frida-a committed Oct 14, 2023
1 parent 362f3fb commit 19c7af2
Showing 5 changed files with 240 additions and 5 deletions.
15 changes: 15 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py
@@ -72,6 +72,12 @@ def FindParsingFunctionFromAttributeType(atype):
" auto {} = {}(\"{}\", \"{}\", args, {}, {});\n"
)

CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_TEMPLATE = """
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh{inputs})) {{
ConvertAllInputsToDistTensor(mesh{inputs});
}}
"""

PARSE_PYTHON_C_ARGS_TEMPLATE = """ PyObject* {}_obj = PyTuple_GET_ITEM(args, {});
{} {} = {}({}_obj, \"{}\", {});
@@ -325,7 +331,9 @@ def GeneratePythonCFunction(self):
inplace_returns_pos_map = {}
# Generate Python-C Tensors Parsing Logic
get_eager_tensor_str = ""
input_names = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
input_names = input_names + ", " + name
if forward_inplace_map and name in forward_inplace_map.keys():
inplace_args_pos_map[name] = pos
is_optional = name in optional_inputs
@@ -375,6 +383,13 @@ def GeneratePythonCFunction(self):
"false",
)
)
# If the API has no tensor inputs, skip the conversion to DistTensor.
if len(input_names) > 0:
get_eager_tensor_str += (
CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_TEMPLATE.format(
inputs=input_names
)
)

if forward_inplace_map:
for name, (ttype, pos) in forward_outputs_position_map.items():
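
For illustration, when the template above is filled in for a hypothetical API with two tensor inputs named x and y (placeholder names, not taken from this PR), the generated Python-C function would contain roughly the following C++ fragment:

const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, x, y)) {
  // At least one input is a DistTensor: wrap the remaining DenseTensor
  // inputs as DistTensors on the same process mesh before dispatch.
  ConvertAllInputsToDistTensor(mesh, x, y);
}
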
115 changes: 115 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
@@ -2000,5 +2000,120 @@ void* UnPackHook::operator()(void* packed_value, void* other) {
return reinterpret_cast<void*>(ret);
}

/* ------------------ for auto parallel ----------------------- */

void DistTensorTypeParser::operator()(const Tensor& x) {
if (x.is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(x.impl())
->dist_attr()
.process_mesh());
result = true;
}
}

void DistTensorTypeParser::operator()(const paddle::optional<Tensor>& x) {
if (x) {
if (x.get_ptr()->is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
x.get_ptr()->impl())
->dist_attr()
.process_mesh());
result = true;
}
}
}

void DistTensorTypeParser::operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
if (t.is_dist_tensor()) {
*mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
result = true;
break;
}
}
}
}

void DistTensorTypeParser::operator()(
const paddle::optional<std::vector<Tensor>>& x) {
if (x) {
if (!(x.get_ptr()->empty())) {
for (auto& t : *(x.get_ptr())) {
if (t.is_dist_tensor()) {
*mesh = &(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
result = true;
break;
}
}
}
}
}

void DistTensorConverter::convert(Tensor* x) {
if (x->is_dist_tensor()) {
PADDLE_ENFORCE_EQ(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl())
->dist_attr()
.process_mesh(),
*mesh,
platform::errors::InvalidArgument(
"Input %s has different mesh. However all inputs should "
"have the same mesh.",
x->name()));
return;
} else {
PADDLE_ENFORCE_EQ(
phi::DenseTensor::classof(x->impl().get()),
true,
platform::errors::InvalidArgument(
"Failed to convert input %s impl to phi::distributed::DistTensor "
"as it's not phi::DenseTensor.",
x->name()));
phi::distributed::TensorDistAttr dist_attr(
phi::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = static_cast<phi::DenseTensor*>(x->impl().get());
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(*dense_t, dist_attr));
}
}

void DistTensorConverter::operator()(Tensor* x) {
DistTensorConverter::convert(x);
}

void DistTensorConverter::operator()(paddle::optional<Tensor>* x) {
if (*x) {
DistTensorConverter::convert(x->get_ptr());
}
}

void DistTensorConverter::operator()(std::vector<Tensor>* x) {
if (!x->empty()) {
for (auto& t : *x) {
DistTensorConverter::convert(&t);
}
}
}

void DistTensorConverter::operator()(paddle::optional<std::vector<Tensor>>* x) {
if (*x) {
if (!(x->get_ptr()->empty())) {
for (auto& t : *(x->get_ptr())) {
if (!t.is_dist_tensor()) {
DistTensorConverter::convert(&t);
}
}
}
}
}

} // namespace pybind
} // namespace paddle
68 changes: 68 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
@@ -29,13 +29,16 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/pir/core/op_result.h"
#include "paddle/utils/pybind.h"
@@ -384,5 +387,70 @@ class eager_gil_scoped_release {
PyThreadState* tstate{nullptr};
};

/* ------------------ for auto parallel ----------------------- */
using paddle::experimental::detail::ArgsIterator;

struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = false;
const phi::distributed::ProcessMesh** mesh = nullptr;

explicit DistTensorTypeParser(const phi::distributed::ProcessMesh** m)
: mesh(m) {}

bool short_circuit() { return result; }

void operator()(const Tensor& x);
void operator()(const paddle::optional<Tensor>& x);
void operator()(const std::vector<Tensor>& x);
void operator()(const paddle::optional<std::vector<Tensor>>& x);

// Skip args of other types; they are not used in kernel selection.
template <typename T>
void operator()(const T& x) {
// do nothing
}
};

struct DistTensorConverter : ArgsIterator<DistTensorConverter> {
const phi::distributed::ProcessMesh* mesh = nullptr;

explicit DistTensorConverter(const phi::distributed::ProcessMesh* m) {
PADDLE_ENFORCE_NE(
m,
nullptr,
platform::errors::InvalidArgument(
"Input mesh of DistTensorConverter() shouldn't be nullptr."));
mesh = m;
}

void convert(Tensor* x);
void operator()(Tensor* x);
void operator()(paddle::optional<Tensor>* x);
void operator()(std::vector<Tensor>* x);
void operator()(paddle::optional<std::vector<Tensor>>* x);

// Skip args of other types; they are not used in kernel selection.
template <typename T>
void operator()(const T& x) {
// do nothing
}
};

template <typename... Args>
bool InputsContainDistTensor(const phi::distributed::ProcessMesh** mesh,
const Args&... args) {
return DistTensorTypeParser(mesh).apply(args...).result;
}

template <typename... Args>
void ConvertAllInputsToDistTensor(const phi::distributed::ProcessMesh* mesh,
Args&... args) {
PADDLE_ENFORCE_NE(
mesh,
nullptr,
platform::errors::InvalidArgument("Input mesh should not be nullptr."));
DistTensorConverter(mesh).apply(&args...);
}

} // namespace pybind
} // namespace paddle
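
The two helpers above are meant to be used in sequence: InputsContainDistTensor scans the arguments and, if any of them is a DistTensor, records its process mesh; ConvertAllInputsToDistTensor then wraps every remaining DenseTensor argument on that mesh. A minimal sketch of a hand-written call site follows; it is not part of this PR, and the function name, argument names, and the paddle::Tensor / paddle::optional spellings are assumptions made for the example:

// Sketch only: exercises the Tensor, paddle::optional<Tensor>, and
// std::vector<Tensor> overloads declared above.
void PrepareMixedInputs(paddle::Tensor& x,
                        paddle::optional<paddle::Tensor>& bias,
                        std::vector<paddle::Tensor>& others) {
  const phi::distributed::ProcessMesh* mesh = nullptr;
  // Phase 1: find the shared mesh if any argument is already a DistTensor.
  if (paddle::pybind::InputsContainDistTensor(&mesh, x, bias, others)) {
    // Phase 2: convert the remaining DenseTensor arguments in place so the
    // operator sees a uniform set of DistTensors on one mesh.
    paddle::pybind::ConvertAllInputsToDistTensor(mesh, x, bias, others);
  }
}
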
12 changes: 7 additions & 5 deletions paddle/phi/api/lib/kernel_dispatch.h
@@ -171,20 +171,22 @@ struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
/* ------------------ for auto parallel ----------------------- */

struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = true;
bool result = false;

void operator()(const Tensor& x) { result &= x.is_dist_tensor(); }
bool short_circuit() { return result; }

void operator()(const Tensor& x) { result = x.is_dist_tensor(); }

void operator()(const paddle::optional<Tensor>& x) {
if (x) {
result &= x.get_ptr()->is_dist_tensor();
result = x.get_ptr()->is_dist_tensor();
}
}

void operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
result &= t.is_dist_tensor();
result = t.is_dist_tensor();
}
}
}
@@ -193,7 +195,7 @@ struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
if (x) {
if (!(x.get_ptr()->empty())) {
for (auto& t : *(x.get_ptr())) {
result &= t.is_dist_tensor();
result = t.is_dist_tensor();
}
}
}
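
The edit above also flips the parser's semantics: result used to start as true and be AND-ed over the inputs (auto-parallel dispatch only when every input was a DistTensor), whereas it now starts as false and the traversal can stop early through short_circuit() once a DistTensor is found, so dispatch takes the auto-parallel path as soon as any input is a DistTensor. A simplified sketch of that intent for a plain list of tensors (an illustration only; the real parser walks heterogeneous arguments through ArgsIterator):

// Illustration: "any input is a DistTensor", with an early exit that mirrors
// the role of short_circuit() in the ArgsIterator-based parser above.
bool AnyInputIsDistTensor(const std::vector<paddle::Tensor>& inputs) {
  for (const auto& t : inputs) {
    if (t.is_dist_tensor()) {
      return true;  // stop scanning at the first DistTensor
    }
  }
  return false;
}
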
35 changes: 35 additions & 0 deletions test/auto_parallel/test_api_dist_branch.py
@@ -60,6 +60,41 @@ def create_local_and_dist_tensor_list_pair(self, np_array_list):
dist_t_list.append(dist_t)
return local_t_list, dist_t_list

def create_two_local_tensor_pair(self, np_array):
if np_array.dtype == np.float32:
local_t_1 = paddle.to_tensor(np_array, dtype='float32')
local_t_2 = paddle.to_tensor(np_array, dtype='float32')
elif np_array.dtype == np.float16:
local_t_1 = paddle.to_tensor(np_array, dtype='float16')
local_t_2 = paddle.to_tensor(np_array, dtype='float16')
elif np_array.dtype == np.int32:
local_t_1 = paddle.to_tensor(np_array, dtype='int32')
local_t_2 = paddle.to_tensor(np_array, dtype='int32')
elif np_array.dtype == np.bool_:
local_t_1 = paddle.to_tensor(np_array, dtype='bool')
local_t_2 = paddle.to_tensor(np_array, dtype='bool')

local_t_1.stop_gradient = False
local_t_2.stop_gradient = False

return local_t_1, local_t_2

# mixed input types: DenseTensor and DistTensor
def test_matmul_api_for_mixed_inputs_type(self):
x = np.random.random(size=[4, 4]).astype("float32")
y = np.random.random(size=[4, 4]).astype("float32")
local_x, dist_x = self.create_local_and_dist_tensor_pair(x)
local_y_1, local_y_2 = self.create_two_local_tensor_pair(y)
local_out = paddle.matmul(local_x, local_y_1)
dist_out = paddle.matmul(dist_x, local_y_2)
self.check_tensor_eq(local_out, dist_out)

# test backward
local_out.backward()
dist_out.backward()
self.check_tensor_eq(local_x.grad, dist_x.grad)
self.check_tensor_eq(local_y_1.grad, local_y_2.grad)

# input: std::vector<phi::Tensor>
# output: phi::Tensor
def test_concat_for_dist_tensor(self):
