Add type promotion logic for eager mode between tensor and tensor #59518

Merged (23 commits) on Dec 7, 2023
@@ -19,9 +19,11 @@
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/eager/eager_layout_auto_tune.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/eager/type_promotion_utils.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/phi/common/type_promotion.h"
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(check_nan_inf);
@@ -56,6 +58,20 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type protmotion automatically.";
LOG(WARNING) << "got different data type, run type protmotion "
"automatically, this may cause data type been changed.";
auto op_name = phi::TransToFluidOpName("multiply");
auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype());

auto new_x = egr::PromoteCast("x", x, promotion_type);
auto new_y = egr::PromoteCast("y", y, promotion_type);

return multiply_ad_func(new_x, new_y);
}

// Layout autotune

if (egr::Controller::Instance().UseLayoutAutoTune()) {
@@ -388,6 +404,20 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
}
}

// Type promotion Logic
if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
VLOG(5) << "got different data type, run type protmotion automatically.";
LOG(WARNING) << "got different data type, run type protmotion "
"automatically, this may cause data type been changed.";
auto op_name = phi::TransToFluidOpName("multiply");
auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype());

auto new_x = egr::PromoteCast("x", x, promotion_type);
auto new_y = egr::PromoteCast("y", y, promotion_type);

return multiply_ad_func(new_x, new_y);
}

// Layout autotune

if (egr::Controller::Instance().UseLayoutAutoTune()) {
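For readers skimming the generated code above: the new guard in multiply_ad_func fires only when the operand dtypes differ, computes a common dtype, casts both inputs, and then re-enters the same forward function so the AMP, layout, and autograd handling runs on the promoted tensors. Below is a minimal, self-contained Python sketch of that control flow; it is an illustration only (the helper names and the single promotion rule are stand-ins, not Paddle source), with the real rules living behind phi::NeedTypePromotion and phi::GetPromoteDtype.

from dataclasses import dataclass

@dataclass
class Tensor:
    dtype: str

def need_type_promotion(a: str, b: str) -> bool:  # stand-in for phi::NeedTypePromotion
    return a != b

def get_promote_dtype(op: str, a: str, b: str) -> str:  # stand-in for phi::GetPromoteDtype
    # single example rule, taken from the promotion table later in this diff (f2 x f4 -> f4)
    return "float32" if {a, b} == {"float16", "float32"} else a

def promote_cast(name: str, t: Tensor, dst: str) -> Tensor:  # mirrors egr::PromoteCast
    return t if t.dtype == dst else Tensor(dst)

def multiply_ad_func(x: Tensor, y: Tensor) -> Tensor:
    if need_type_promotion(x.dtype, y.dtype):
        promoted = get_promote_dtype("multiply", x.dtype, y.dtype)
        # cast both operands, then re-dispatch the same forward function
        return multiply_ad_func(promote_cast("x", x, promoted),
                                promote_cast("y", y, promoted))
    return Tensor(x.dtype)  # actual kernel call and autograd node creation elided

print(multiply_ad_func(Tensor("float16"), Tensor("float32")).dtype)  # float32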
59 changes: 58 additions & 1 deletion paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -75,6 +75,14 @@
"tanh_triple_grad",
]

# white ops list whose kernels can automatically do type promotion.
# in the future this list will be obtained from the same place as the static graph.
type_promote_white_list = {
"add": ["x", "y"],
"subtract": ["x", "y"],
"where": ["x", "y"],
}

# dict of special apis whose forward output will affect the backward api's output
# (a backward api's output is usually affected by the backward api's input)
special_prune_dict = {
@@ -247,6 +255,8 @@ class {} : public egr::GradNodeBase {{
// Dygraph Record Event
{}
// AMP Logic
{}
// Type promotion Logic
{}
// Layout autotune
{}
@@ -315,6 +325,8 @@ class {} : public egr::GradNodeBase {{
// Dygraph Record Event
{}
// AMP Logic
{}
// Type promotion Logic
{}
// Layout autotune
{}
@@ -447,7 +459,8 @@ class {} : public egr::GradNodeBase {{
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/api/lib/data_transform.h"

#include "paddle/fluid/eager/type_promotion_utils.h"
#include "paddle/phi/common/type_promotion.h"
PHI_DECLARE_bool(check_nan_inf);
PHI_DECLARE_string(tensor_operants_mode);
{}
Expand Down Expand Up @@ -512,6 +525,21 @@ class {} : public egr::GradNodeBase {{
}}
}}
"""

TYPE_PROMOTION_LOGIC_TEMPLATE = """ if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{
VLOG(5) << "got different data type, run type protmotion automatically.";
LOG(WARNING) << "got different data type, run type protmotion automatically, this may cause data type been changed.";
{op_name}
auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());

auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);
auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);

{return_value}
}}
"""


LAYOUT_LOGIC_TEMPLATE = """
if (egr::Controller::Instance().UseLayoutAutoTune()) {{
paddle::small_vector<std::vector<paddle::Tensor>, egr::kSlotSmallVectorSize> tensors_vector = {};
@@ -1459,6 +1487,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
inputs_call_list = ["" for i in range(num_inputs)]

amp_inputs_call_list = ["" for i in range(num_inputs)]
type_promote_inputs_call_list = ["" for i in range(num_inputs)]
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
@@ -1470,6 +1499,11 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
inputs_call_list[pos] = f"{name}"
amp_inputs_call_list[pos] = f"new_{name}"
is_optional = name in optional_inputs
if forward_api_name in type_promote_white_list:
if name in type_promote_white_list[forward_api_name]:
type_promote_inputs_call_list[pos] = f"new_{name}"
else:
type_promote_inputs_call_list[pos] = f"{name}"
if IsPlainTensorType(ttype):
if is_optional:
if (
@@ -1804,7 +1838,28 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
amp_autocast_list_str,
amp_call_str,
)
# Forward type promotion logic
if forward_api_name in type_promote_white_list:
# only two inputs are supported
x = type_promote_white_list[forward_api_name][0]
y = type_promote_white_list[forward_api_name][1]
type_promote_inputs_call_args_str = ", ".join(
type_promote_inputs_call_list
)
type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"

type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
x=x,
y=y,
op_name=kernel_trans2_op_name_str,
return_value=type_promote_call_list,
)
else:
type_promotion_logic_str = (
"\n VLOG(5) << \" No Type Promotion for {} api. \"; ".format(
forward_ad_function_name
)
)
# Forward layout autotune
layout_autotune_list_str = " ".join(
layout_autotune_list
@@ -1849,6 +1904,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
forward_api_name,
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
layout_logic_str,
forward_api_name,
before_log_str,
@@ -1871,6 +1927,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
forward_api_name,
dygraph_event_str,
amp_logic_str,
type_promotion_logic_str,
layout_logic_str,
inputs_autograd_meta_str,
forward_api_name,
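As a concrete example of the generator path above, the snippet below fills TYPE_PROMOTION_LOGIC_TEMPLATE for the "add" entry of type_promote_white_list and prints the resulting C++ guard. The white list and template are copied from this diff; the op_name line is an assumption about what kernel_trans2_op_name_str expands to, modeled on the generated multiply_ad_func shown earlier.

# Standalone illustration: fill TYPE_PROMOTION_LOGIC_TEMPLATE for the "add" op.
type_promote_white_list = {
    "add": ["x", "y"],
    "subtract": ["x", "y"],
    "where": ["x", "y"],
}

TYPE_PROMOTION_LOGIC_TEMPLATE = """  if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{
    VLOG(5) << "got different data type, run type promotion automatically.";
    LOG(WARNING) << "got different data type, run type promotion automatically, this may cause the data type to be changed.";
    {op_name}
    auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());

    auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);
    auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);

    {return_value}
  }}
"""

x, y = type_promote_white_list["add"]
print(
    TYPE_PROMOTION_LOGIC_TEMPLATE.format(
        x=x,
        y=y,
        # assumed expansion of kernel_trans2_op_name_str for this op:
        op_name='auto op_name = phi::TransToFluidOpName("add");',
        # whitelisted inputs are passed as new_x / new_y, per the generator code above:
        return_value="return add_ad_func(new_x, new_y);",
    )
)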
87 changes: 9 additions & 78 deletions paddle/fluid/eager/type_promotion_utils.h
@@ -12,90 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/phi/common/type_promotion.h"

namespace egr {

inline int DataTypeToNum(const phi::DataType& dtype) {
switch (dtype) {
case phi::DataType::UINT8:
return 0;
case phi::DataType::INT8:
return 1;
case phi::DataType::INT16:
return 2;
case phi::DataType::INT32:
return 3;
case phi::DataType::INT64:
return 4;
case phi::DataType::FLOAT16:
return 5;
case phi::DataType::FLOAT32:
return 6;
case phi::DataType::FLOAT64:
return 7;
case phi::DataType::COMPLEX64:
return 8;
case phi::DataType::COMPLEX128:
return 9;
case phi::DataType::BOOL:
return 10;
case phi::DataType::BFLOAT16:
return 11;
default:
PD_THROW("Invalid enum data type for type promote `", dtype, "`.");
}
}

static inline bool is_support_float(phi::DataType dtype) {
if (dtype == phi::DataType::FLOAT16 || dtype == phi::DataType::FLOAT32 ||
dtype == phi::DataType::FLOAT64 || dtype == phi::DataType::BFLOAT16) {
return true;
} else {
return false;
}
}

static inline bool is_support_int(phi::DataType dtype) {
if (dtype == phi::DataType::INT32 || dtype == phi::DataType::INT64) {
return true;
inline paddle::Tensor PromoteCast(const std::string& input_name,
const paddle::Tensor& input,
const phi::DataType& dst_dtype,
bool trace_backward = true) {
if (input.dtype() != dst_dtype) {
return Cast(input, dst_dtype, trace_backward);
} else {
return false;
return input;
}
}

inline static phi::DataType promoteTypes(phi::DataType a, phi::DataType b) {
constexpr auto u1 = phi::DataType::UINT8;
constexpr auto i1 = phi::DataType::INT8;
constexpr auto i2 = phi::DataType::INT16;
constexpr auto i4 = phi::DataType::INT32;
constexpr auto i8 = phi::DataType::INT64;
constexpr auto f2 = phi::DataType::FLOAT16;
constexpr auto f4 = phi::DataType::FLOAT32;
constexpr auto f8 = phi::DataType::FLOAT64;
constexpr auto c4 = phi::DataType::COMPLEX64;
constexpr auto c8 = phi::DataType::COMPLEX128;
constexpr auto b1 = phi::DataType::BOOL;
constexpr auto bf = phi::DataType::BFLOAT16;

static constexpr phi::DataType _promoteTypesLookup[12][12] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 c4 c8 b1 bf*/
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c4, c8, u1, bf},
/* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c4, c8, i1, bf},
/* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c4, c8, i2, bf},
/* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c4, c8, i4, bf},
/* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c4, c8, i8, bf},
/* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c4, c8, f2, f4},
/* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c8, f4, f4},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, f8, f8},
/* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c8, c4, c4},
/* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8},
/* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c4, c8, b1, bf},
/* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c8, bf, bf},
};

return _promoteTypesLookup[DataTypeToNum(a)][DataTypeToNum(b)];
}

} // namespace egr
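A note on the removal above: the hand-written DataTypeToNum / promoteTypes lattice leaves egr and, judging by the new includes, the same rules are now consumed through phi::GetPromoteDtype in paddle/phi/common/type_promotion.h. For reference while reviewing, here is a tiny Python transcription of a few entries of the deleted 12x12 table with a symmetric lookup helper; it is a reading aid, not Paddle code.

# A few representative entries transcribed from the removed promoteTypes table.
_PROMOTE_EXAMPLES = {
    ("uint8", "int8"): "int16",          # u1 x i1 -> i2: smallest type holding both
    ("bool", "float16"): "float16",      # b1 row: bool defers to the other operand
    ("float16", "bfloat16"): "float32",  # f2 x bf -> f4: no common 16-bit float
    ("float32", "complex64"): "complex64",
}

def promote(a: str, b: str) -> str:
    if a == b:
        return a
    # the table is symmetric, so look up the pair in either order
    return _PROMOTE_EXAMPLES.get((a, b)) or _PROMOTE_EXAMPLES[(b, a)]

assert promote("bfloat16", "float16") == "float32"
assert promote("bool", "float16") == "float16"
assert promote("int8", "uint8") == "int16"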
22 changes: 14 additions & 8 deletions paddle/fluid/pybind/eager_math_op_patch.cc
@@ -51,6 +51,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/common/type_promotion.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
@@ -252,10 +253,11 @@ static PyObject* tensor__add__method(TensorObject* self,
ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
}

// 3. promote types or unify right var type to left var
// 3. promote types or unify right var type to left var; float type promotion
// has moved to add_ad_func
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
if (lhs_dtype != rhs_dtype && !phi::NeedTypePromotion(lhs_dtype, rhs_dtype)) {
// note: only op_type in _supported_promote_complex_types_ should promote
// dtype
if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
@@ -358,10 +360,11 @@ static PyObject* tensor__sub__method(TensorObject* self,
ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
}

// 3. promote types or unify right var type to left var
// 3. promote types or unify right var type to left var; float type promotion
// has moved to subtract_ad_func
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
if (lhs_dtype != rhs_dtype && !phi::NeedTypePromotion(lhs_dtype, rhs_dtype)) {
if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
_complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
phi::DataType promote_dtype =
@@ -386,6 +389,7 @@ static PyObject* tensor__sub__method(TensorObject* self,
other_tensor = cast_ad_func(other_tensor, lhs_dtype);
}
}

// 4. calculation
VLOG(6) << "Calling subtract_ad_func in tensor__sub__method";
{
@@ -460,10 +464,11 @@ static PyObject* tensor__rsub__method(TensorObject* self,
ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
}

// 3. promote types or unify right var type to left var
// 3. promote types or unify right var type to left var; float type promotion
// has moved to subtract_ad_func
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
if (lhs_dtype != rhs_dtype && !phi::NeedTypePromotion(lhs_dtype, rhs_dtype)) {
if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
_complex_dtypes.find(rhs_dtype) != _complex_dtypes.end()) {
phi::DataType promote_dtype =
@@ -568,10 +573,11 @@ static PyObject* tensor__mul__method(TensorObject* self,
ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
}

// 3. promote types or unify right var type to left var
// 3. promote types or unify right var type to left var; float type promotion
// has moved to multiply_ad_func
phi::DataType lhs_dtype = self_tensor.dtype();
phi::DataType rhs_dtype = other_tensor.dtype();
if (lhs_dtype != rhs_dtype) {
if (lhs_dtype != rhs_dtype && !phi::NeedTypePromotion(lhs_dtype, rhs_dtype)) {
// note: only op_type in _supported_promote_complex_types_ should promote
// dtype
if (_complex_dtypes.find(lhs_dtype) != _complex_dtypes.end() ||
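Observable effect of the eager_math_op_patch.cc changes: when NeedTypePromotion is true for the two operand dtypes, __add__ / __sub__ / __rsub__ / __mul__ skip the old cast of the right-hand operand to the left-hand dtype and fall through to the corresponding *_ad_func, which promotes both sides. A hedged usage sketch follows, assuming a Paddle build that contains this PR and a device where float16 multiply is supported; the expected result dtype follows the f2 x f4 -> f4 entry of the promotion table.

import paddle

a = paddle.ones([2, 2], dtype="float16")
b = paddle.ones([2, 2], dtype="float32")

# Routed through multiply_ad_func, which now promotes both operands instead of
# casting b down to a's dtype inside tensor__mul__method.
c = a * b
print(c.dtype)  # expected: paddle.float32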