[AutoParallel] Support operators that have mixed inputs. #57774

Merged
15 changes: 15 additions & 0 deletions paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py
@@ -72,6 +72,12 @@ def FindParsingFunctionFromAttributeType(atype):
" auto {} = {}(\"{}\", \"{}\", args, {}, {});\n"
)

CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_TEMPLATE = """
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh{inputs})) {{
ConvertAllInputsToDistTensor(mesh{inputs});
}}
"""

PARSE_PYTHON_C_ARGS_TEMPLATE = """ PyObject* {}_obj = PyTuple_GET_ITEM(args, {});
{} {} = {}({}_obj, \"{}\", {});
@@ -325,7 +331,9 @@ def GeneratePythonCFunction(self):
inplace_returns_pos_map = {}
# Generate Python-C Tensors Parsing Logic
get_eager_tensor_str = ""
input_names = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
input_names = input_names + ", " + name
if forward_inplace_map and name in forward_inplace_map.keys():
inplace_args_pos_map[name] = pos
is_optional = name in optional_inputs
@@ -375,6 +383,13 @@ def GeneratePythonCFunction(self):
"false",
)
)
# If the op has no tensor inputs, skip the DistTensor conversion
if len(input_names) > 0:
get_eager_tensor_str += (
CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_TEMPLATE.format(
inputs=input_names
)
)

if forward_inplace_map:
for name, (ttype, pos) in forward_outputs_position_map.items():
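
To make the template's effect concrete, here is a minimal sketch of what the generated Python-C body might contain for a hypothetical operator with two tensor inputs named x and y; the operator and input names are illustrative only (input_names accumulates to ", x, y" in that case, so the template expands as follows):

// Illustrative expansion of CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_TEMPLATE for
// a hypothetical op with tensor inputs "x" and "y" (not taken from the PR):
const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, x, y)) {
  ConvertAllInputsToDistTensor(mesh, x, y);
}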
115 changes: 115 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
@@ -2000,5 +2000,120 @@ void* UnPackHook::operator()(void* packed_value, void* other) {
return reinterpret_cast<void*>(ret);
}

/* ------------------ for auto parallel ----------------------- */

void DistTensorTypeParser::operator()(const Tensor& x) {
if (x.is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(x.impl())
->dist_attr()
.process_mesh());
result = true;
}
}

void DistTensorTypeParser::operator()(const paddle::optional<Tensor>& x) {
if (x) {
if (x.get_ptr()->is_dist_tensor()) {
*mesh = &(std::dynamic_pointer_cast<phi::distributed::DistTensor>(
x.get_ptr()->impl())
->dist_attr()
.process_mesh());
result = true;
}
}
}

void DistTensorTypeParser::operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
if (t.is_dist_tensor()) {
*mesh =
&(std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
result = true;
break;
}
}
}
}

void DistTensorTypeParser::operator()(
const paddle::optional<std::vector<Tensor>>& x) {
if (x) {
if (!(x.get_ptr()->empty())) {
for (auto& t : *(x.get_ptr())) {
if (t.is_dist_tensor()) {
*mesh = &(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(t.impl())
->dist_attr()
.process_mesh());
result = true;
break;
}
}
}
}
}

void DistTensorConverter::convert(Tensor* x) {
if (x->is_dist_tensor()) {
PADDLE_ENFORCE_EQ(
std::dynamic_pointer_cast<phi::distributed::DistTensor>(x->impl())
->dist_attr()
.process_mesh(),
*mesh,
platform::errors::InvalidArgument(
"Input %s has different mesh. However all inputs should "
"have the same mesh.",
x->name()));
return;
} else {
PADDLE_ENFORCE_EQ(
phi::DenseTensor::classof(x->impl().get()),
true,
platform::errors::InvalidArgument(
"Failed to convert input %s impl to phi::distributed::DistTensor "
"as it's not phi::DenseTensor.",
x->name()));
phi::distributed::TensorDistAttr dist_attr(
phi::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = static_cast<phi::DenseTensor*>(x->impl().get());
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(*dense_t, dist_attr));
}
}

void DistTensorConverter::operator()(Tensor* x) {
DistTensorConverter::convert(x);
}

void DistTensorConverter::operator()(paddle::optional<Tensor>* x) {
if (*x) {
DistTensorConverter::convert(x->get_ptr());
}
}

void DistTensorConverter::operator()(std::vector<Tensor>* x) {
if (!x->empty()) {
for (auto& t : *x) {
DistTensorConverter::convert(&t);
}
}
}

void DistTensorConverter::operator()(paddle::optional<std::vector<Tensor>>* x) {
if (*x) {
if (!(x->get_ptr()->empty())) {
for (auto& t : *(x->get_ptr())) {
if (!t.is_dist_tensor()) {
DistTensorConverter::convert(&t);
}
}
}
}
}

} // namespace pybind
} // namespace paddle
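
The core of the DenseTensor branch in DistTensorConverter::convert() is the wrapping step: reuse the existing dense allocation and attach a TensorDistAttr bound to the shared mesh. Below is a stand-alone sketch of just that step, using the same calls that appear in the diff; the helper name WrapAsDistTensor is hypothetical and the placement details (default dist attr) are an assumption.

// Hypothetical helper mirroring the DenseTensor branch of
// DistTensorConverter::convert(): wrap an existing DenseTensor impl into a
// DistTensor whose dist_attr is bound to `mesh` (default placement).
std::shared_ptr<phi::TensorBase> WrapAsDistTensor(
    const std::shared_ptr<phi::TensorBase>& impl,
    const phi::distributed::ProcessMesh& mesh) {
  auto* dense = static_cast<phi::DenseTensor*>(impl.get());
  phi::distributed::TensorDistAttr dist_attr(phi::vectorize(dense->dims()));
  dist_attr.set_process_mesh(mesh);
  return std::make_shared<phi::distributed::DistTensor>(*dense, dist_attr);
}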
68 changes: 68 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
@@ -29,13 +29,16 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/pir/core/op_result.h"
#include "paddle/utils/pybind.h"
@@ -384,5 +387,70 @@ class eager_gil_scoped_release {
PyThreadState* tstate{nullptr};
};

/* ------------------ for auto parallel ----------------------- */
using paddle::experimental::detail::ArgsIterator;

struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = false;
const phi::distributed::ProcessMesh** mesh = nullptr;

explicit DistTensorTypeParser(const phi::distributed::ProcessMesh** m)
: mesh(m) {}

bool short_circuit() { return result; }

void operator()(const Tensor& x);
void operator()(const paddle::optional<Tensor>& x);
void operator()(const std::vector<Tensor>& x);
void operator()(const paddle::optional<std::vector<Tensor>>& x);

  // skip args of other types; they are not used in kernel selection
template <typename T>
void operator()(const T& x) {
// do nothing
}
};

struct DistTensorConverter : ArgsIterator<DistTensorConverter> {
const phi::distributed::ProcessMesh* mesh = nullptr;

explicit DistTensorConverter(const phi::distributed::ProcessMesh* m) {
PADDLE_ENFORCE_NE(
m,
nullptr,
platform::errors::InvalidArgument(
"Input mesh of DistTensorConverter() shouldn't be nullptr."));
mesh = m;
}

void convert(Tensor* x);
void operator()(Tensor* x);
void operator()(paddle::optional<Tensor>* x);
void operator()(std::vector<Tensor>* x);
void operator()(paddle::optional<std::vector<Tensor>>* x);

  // skip args of other types; they are not used in kernel selection
template <typename T>
void operator()(const T& x) {
// do nothing
}
};

template <typename... Args>
bool InputsContainDistTensor(const phi::distributed::ProcessMesh** mesh,
const Args&... args) {
return DistTensorTypeParser(mesh).apply(args...).result;
}

template <typename... Args>
void ConvertAllInputsToDistTensor(const phi::distributed::ProcessMesh* mesh,
Args&... args) {
PADDLE_ENFORCE_NE(
mesh,
nullptr,
platform::errors::InvalidArgument("Input mesh should not be nullptr."));
DistTensorConverter(mesh).apply(&args...);
}

} // namespace pybind
} // namespace paddle
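
Both visitors piggyback on ArgsIterator's apply(...), which walks the argument pack and stops as soon as short_circuit() returns true. The snippet below is a simplified, self-contained stand-in for that pattern, written only to illustrate the control flow; the real ArgsIterator lives in paddle/phi/api/lib/kernel_dispatch.h and may differ in detail, and MiniArgsIterator/AnyDist are invented names.

#include <initializer_list>

// Minimal stand-in for the CRTP args-iterator used by the visitors above.
template <typename Derived>
struct MiniArgsIterator {
  Derived& self() { return *static_cast<Derived*>(this); }

  // Visit each argument in order, skipping the rest once the derived
  // visitor reports that it can short-circuit.
  template <typename... Args>
  Derived& apply(const Args&... args) {
    (void)std::initializer_list<int>{
        (self().short_circuit() ? 0 : (self()(args), 0))...};
    return self();
  }
};

// Toy visitor shaped like DistTensorTypeParser: stops at the first "dist"
// argument (modelled here as a bool).
struct AnyDist : MiniArgsIterator<AnyDist> {
  bool result = false;
  bool short_circuit() { return result; }
  void operator()(bool is_dist) { result = result || is_dist; }
  template <typename T>
  void operator()(const T&) {}  // other argument types are ignored
};

// Usage: AnyDist().apply(0.5, false, true, false).result == true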
12 changes: 7 additions & 5 deletions paddle/phi/api/lib/kernel_dispatch.h
@@ -171,20 +171,22 @@ struct KernelTypeParser : ArgsIterator<KernelTypeParser> {
/* ------------------ for auto parallel ----------------------- */

struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
bool result = true;
bool result = false;

void operator()(const Tensor& x) { result &= x.is_dist_tensor(); }
bool short_circuit() { return result; }

void operator()(const Tensor& x) { result = x.is_dist_tensor(); }

void operator()(const paddle::optional<Tensor>& x) {
if (x) {
result &= x.get_ptr()->is_dist_tensor();
result = x.get_ptr()->is_dist_tensor();
}
}

void operator()(const std::vector<Tensor>& x) {
if (!x.empty()) {
for (auto& t : x) {
result &= t.is_dist_tensor();
result = t.is_dist_tensor();
}
}
}
@@ -193,7 +195,7 @@ struct DistTensorTypeParser : ArgsIterator<DistTensorTypeParser> {
if (x) {
if (!(x.get_ptr()->empty())) {
for (auto& t : *(x.get_ptr())) {
result &= t.is_dist_tensor();
result = t.is_dist_tensor();
}
}
}
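
Note the semantic flip in this hunk: previously result started at true and was accumulated with &=, so the parser answered "are all inputs DistTensors?"; now it starts at false, is set whenever a DistTensor is seen, and short_circuit() stops the walk at the first hit, so it answers "is at least one input a DistTensor?". That is what lets an op with mixed DenseTensor/DistTensor inputs take the auto-parallel branch. The following toy, self-contained illustration of the difference is not Paddle code; plain bools stand in for "this input is a DistTensor".

#include <initializer_list>
#include <iostream>

// Old semantics: all inputs must be DistTensors (result starts true, uses &=).
bool all_dist(std::initializer_list<bool> inputs) {
  bool result = true;
  for (bool is_dist : inputs) result &= is_dist;
  return result;
}

// New semantics: any DistTensor input is enough (result starts false and the
// walk stops at the first hit, mirroring short_circuit()).
bool any_dist(std::initializer_list<bool> inputs) {
  bool result = false;
  for (bool is_dist : inputs) {
    if (result) break;
    result = is_dist;
  }
  return result;
}

int main() {
  // One DistTensor input mixed with one DenseTensor input.
  std::cout << all_dist({true, false}) << "\n";  // 0: old check rejects mixed inputs
  std::cout << any_dist({true, false}) << "\n";  // 1: new check accepts them
  return 0;
}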
35 changes: 35 additions & 0 deletions test/auto_parallel/test_api_dist_branch.py
@@ -60,6 +60,41 @@ def create_local_and_dist_tensor_list_pair(self, np_array_list):
dist_t_list.append(dist_t)
return local_t_list, dist_t_list

def create_two_local_tensor_pair(self, np_array):
if np_array.dtype == np.float32:
local_t_1 = paddle.to_tensor(np_array, dtype='float32')
local_t_2 = paddle.to_tensor(np_array, dtype='float32')
elif np_array.dtype == np.float16:
local_t_1 = paddle.to_tensor(np_array, dtype='float16')
local_t_2 = paddle.to_tensor(np_array, dtype='float16')
elif np_array.dtype == np.int32:
local_t_1 = paddle.to_tensor(np_array, dtype='int32')
local_t_2 = paddle.to_tensor(np_array, dtype='int32')
elif np_array.dtype == np.bool_:
local_t_1 = paddle.to_tensor(np_array, dtype='bool')
local_t_2 = paddle.to_tensor(np_array, dtype='bool')

local_t_1.stop_gradient = False
local_t_2.stop_gradient = False

return local_t_1, local_t_2

# mixed input types: DenseTensor and DistTensor
def test_matmul_api_for_mixed_inputs_type(self):
x = np.random.random(size=[4, 4]).astype("float32")
y = np.random.random(size=[4, 4]).astype("float32")
local_x, dist_x = self.create_local_and_dist_tensor_pair(x)
local_y_1, local_y_2 = self.create_two_local_tensor_pair(y)
local_out = paddle.matmul(local_x, local_y_1)
dist_out = paddle.matmul(dist_x, local_y_2)
self.check_tensor_eq(local_out, dist_out)

# test backward
local_out.backward()
dist_out.backward()
self.check_tensor_eq(local_x.grad, dist_x.grad)
self.check_tensor_eq(local_y_1.grad, local_y_2.grad)

# input: std::vector<phi::Tensor>
# output: phi::Tensor
def test_concat_for_dist_tensor(self):