Tensor type promotion for static mode (#59763)
* add type promotion table.

* fix codestyle.

* add python table.

* fix dtype.

* remove useless note

* fix static-check

* add eager T+T logic.

* remove useless file.

* remove useless line.

* fix

* dtype promotion for operator overload in static mode

* only support float series

* update

* fix note.

* mv common logic to common dir.

* fix

* remove deal for int.

* remove int.

* only for compile

* fix median / cross_entropy_loss

* keep old illogical logic for compatibility reasons

* pybind the type_promotion function; remove python function; remove float-complex test

* remove change for dygraph

* rename type_promotion_table.h -> data_type_promotion.h

* warnings

---------

Co-authored-by: zxcd <228587199@qq.com>
zoooo0820 and zxcd authored Dec 8, 2023
1 parent cc07f54 commit 6e09e44
Showing 6 changed files with 228 additions and 73 deletions.
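
In short, static-graph operator overloads now promote mixed floating-point inputs to a common dtype instead of silently casting the right-hand operand to the left-hand dtype. A minimal sketch of the resulting behavior, mirroring the tests added below (the expected output dtype is taken from those tests):

import paddle

paddle.enable_static()
exe = paddle.static.Executor()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = (paddle.randn((4, 3, 2)) * 10).astype('float32')
    y = (paddle.randn((4, 3, 2)) * 10).astype('float64')
    # float32 + float64: both operands are promoted and the add runs in float64,
    # with a warning about the auto-promotion.
    out = x + y

res = exe.run(prog, fetch_list=[out])
print(res[0].dtype)  # expected: float64
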
17 changes: 17 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -204,6 +204,7 @@ limitations under the License. */
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/api/include/operants_manager.h"
#include "paddle/phi/api/include/tensor_operants.h"
#include "paddle/phi/common/type_promotion.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
@@ -883,6 +884,22 @@ PYBIND11_MODULE(libpaddle, m) {
&paddle::prim::PrimCommonUtils::SetTargetGradName);
m.def("set_num_threads", &platform::SetNumThreads);

m.def("need_type_promotion",
[](framework::proto::VarType::Type type_x,
framework::proto::VarType::Type type_y) {
return phi::NeedTypePromotion(framework::TransToPhiDataType(type_x),
framework::TransToPhiDataType(type_y));
});
m.def("get_promote_dtype",
[](const std::string &op_name,
framework::proto::VarType::Type type_x,
framework::proto::VarType::Type type_y) {
return framework::TransToProtoVarType(
phi::GetPromoteDtype(op_name,
framework::TransToPhiDataType(type_x),
framework::TransToPhiDataType(type_y)));
});

m.def("disable_signal_handler", &DisableSignalHandler);

m.def("clear_gradients",
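
These two bindings are re-exported from python/paddle/base/core.py (below) and drive the operator-overload patch in math_op_patch.py. A hedged smoke test of the bindings themselves; the op name 'elementwise_add' is an illustrative assumption, since get_promote_dtype simply takes whatever op_type the caller supplies:

import paddle
from paddle.base import core

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data('x', [2, 3], dtype='float16')
    y = paddle.static.data('y', [2, 3], dtype='float32')
    # Both bindings take framework proto VarType values, which is exactly what
    # a static-graph Variable.dtype holds.
    if core.need_type_promotion(x.dtype, y.dtype):
        common = core.get_promote_dtype('elementwise_add', x.dtype, y.dtype)
        print(common)  # expected: the float32 proto dtype
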
1 change: 0 additions & 1 deletion python/paddle/base/__init__.py
@@ -126,7 +126,6 @@
HeterXpuTrainer,
)
from .backward import append_backward
from . import type_promotion

Tensor = LoDTensor
enable_imperative = enable_dygraph
3 changes: 3 additions & 0 deletions python/paddle/base/core.py
@@ -342,6 +342,9 @@ def to_list(s):
_set_prim_target_grad_name,
)

# type promotion
from .libpaddle import need_type_promotion, get_promote_dtype # noqa: F401

# isort: on
if sys.platform != 'win32':
from .libpaddle import ( # noqa: F401
34 changes: 32 additions & 2 deletions python/paddle/base/layers/math_op_patch.py
@@ -32,6 +32,15 @@

compare_ops = ['__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__']

SUPPORT_PROMOTION_OPS = [
"__add__",
"__radd__",
"__sub__",
"__rsub__",
"__mul__",
"__rmul__",
]

EXPRESSION_MAP = {
"__add__": "A + B",
"__radd__": "A += B",
@@ -519,10 +528,31 @@ def __impl__(self, other_var):
current_block(self), value=other_var, dtype=lhs_dtype
)

# 3. unify right var type to left var
# 3. type promotion
rhs_dtype = safe_get_dtype(other_var)

if lhs_dtype != rhs_dtype:
other_var = astype(other_var, lhs_dtype)
if method_name in SUPPORT_PROMOTION_OPS:
if core.need_type_promotion(lhs_dtype, rhs_dtype):
common_dtype = core.get_promote_dtype(
op_type, lhs_dtype, rhs_dtype
)
warnings.warn(
f"The input dtypes of OP {op_type} are {lhs_dtype} and {rhs_dtype}, "
"the output will be auto-promoted to {common_dtype}"
)
if rhs_dtype != common_dtype:
other_var = astype(other_var, common_dtype)
if lhs_dtype != common_dtype:
self = astype(self, common_dtype)
else:
# NOTE(zoooo0820): Currently, we still keep the old illogical \
# logic for compatibility reasons
other_var = astype(other_var, lhs_dtype)

else:
other_var = astype(other_var, lhs_dtype)

if reverse:
tmp = self
self = other_var
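
Only the six overloads in SUPPORT_PROMOTION_OPS take the promotion path; every other overload (division, comparisons, and so on) keeps the old cast-to-lhs behavior. A sketch of the difference, with output dtypes inferred from the branch above rather than from any documented guarantee:

import paddle

paddle.enable_static()
exe = paddle.static.Executor()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    a = paddle.ones([2, 2], dtype='float32')
    b = paddle.ones([2, 2], dtype='float64')
    promoted = a * b  # '*' is in SUPPORT_PROMOTION_OPS: both sides become float64
    legacy = a / b    # '/' is not: b is cast down to float32, as before this change

res = exe.run(prog, fetch_list=[promoted, legacy])
print(res[0].dtype, res[1].dtype)  # expected: float64 float32
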
70 changes: 0 additions & 70 deletions python/paddle/base/type_promotion.py

This file was deleted.

176 changes: 176 additions & 0 deletions test/legacy_test/test_tensor_type_promotion.py
@@ -51,5 +51,181 @@ def test_operator(self):
self.div_operator()


def create_test_case(baseclass, ldtype, rdtype, expected_out_dtype=None):
class TestPromotion(baseclass):
def set_dtype(self):
self.ldtype = ldtype
self.rdtype = rdtype
self.expected_out_dtype = expected_out_dtype

cls_name = f"{baseclass.__name__}Between{ldtype}And{rdtype}"
TestPromotion.__name__ = cls_name
globals()[cls_name] = TestPromotion


class TestOperatorOverloadAddInStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.set_dtype()
self.exe = paddle.static.Executor()

def set_dtype(self):
self.ldtype = 'float32'
self.rdtype = 'float64'
self.expected_out_dtype = 'float64'

def generate_test_value(self):
self.l_value = (paddle.randn((4, 3, 2)) * 10).astype(self.ldtype)
self.r_value = (paddle.randn((4, 3, 2)) * 10).astype(self.rdtype)

def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value + self.r_value
out_reverse = self.r_value + self.l_value

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res

def test_dtype_is_expected(self):
res = self.run_api()
self.assertEqual(res[0].dtype.__str__(), self.expected_out_dtype)
self.assertEqual(res[1].dtype.__str__(), self.expected_out_dtype)


create_test_case(
TestOperatorOverloadAddInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadAddInStatic, 'float16', 'float64', 'float64'
)

create_test_case(
TestOperatorOverloadAddInStatic, 'float32', 'float64', 'float64'
)


if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadAddInStatic, 'bfloat16', 'float16', 'float32'
)
create_test_case(
TestOperatorOverloadAddInStatic, 'bfloat16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadAddInStatic, 'bfloat16', 'float64', 'float64'
)


class TestOperatorOverloadSubInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value - self.r_value
out_reverse = self.r_value - self.l_value

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(
TestOperatorOverloadSubInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadSubInStatic, 'float16', 'float64', 'float64'
)

create_test_case(
TestOperatorOverloadSubInStatic, 'float32', 'float64', 'float64'
)


if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadSubInStatic, 'bfloat16', 'float16', 'float32'
)
create_test_case(
TestOperatorOverloadSubInStatic, 'bfloat16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadSubInStatic, 'bfloat16', 'float64', 'float64'
)


class TestOperatorOverloadMulInStatic(TestOperatorOverloadAddInStatic):
def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value * self.r_value
out_reverse = self.r_value * self.l_value

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'float16', 'float64', 'float64'
)

create_test_case(
TestOperatorOverloadMulInStatic, 'float32', 'float64', 'float64'
)

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float16', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float32', 'float32'
)
create_test_case(
TestOperatorOverloadMulInStatic, 'bfloat16', 'float64', 'float64'
)


class TestOperatorOverloadGTInStatic(TestOperatorOverloadAddInStatic):
def set_dtype(self):
self.ldtype = 'float32'
self.rdtype = 'float64'
self.expected_out_dtype = 'bool'

def run_api(self):
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
self.generate_test_value()

out = self.l_value > self.r_value
out_reverse = self.r_value > self.l_value

res = self.exe.run(prog, fetch_list=[out, out_reverse])
return res


create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float32', 'bool')
create_test_case(TestOperatorOverloadGTInStatic, 'float16', 'float64', 'bool')

create_test_case(TestOperatorOverloadGTInStatic, 'float32', 'float64', 'bool')

if paddle.is_compiled_with_cuda() and paddle.base.core.supports_bfloat16():
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float16', 'bool'
)
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float32', 'bool'
)
create_test_case(
TestOperatorOverloadGTInStatic, 'bfloat16', 'float64', 'bool'
)


if __name__ == '__main__':
unittest.main()
