diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 4e8026fd92d03..fc2bd9ffea578 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -312,6 +312,15 @@ bool IsCompiledWithCINN() { #endif } +bool IsRunWithCINN() { +#ifndef PADDLE_WITH_CINN + return false; +#else + return framework::paddle2cinn::CinnCompiler::GetInstance() + ->real_compiled_num() > 0; +#endif +} + bool IsCompiledWithHETERPS() { #ifndef PADDLE_WITH_HETERPS return false; @@ -1909,6 +1918,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compiled_with_mpi", IsCompiledWithMPI); m.def("is_compiled_with_mpi_aware", IsCompiledWithMPIAWARE); m.def("is_compiled_with_cinn", IsCompiledWithCINN); + m.def("is_run_with_cinn", IsRunWithCINN); m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS); m.def("supports_bfloat16", SupportsBfloat16); m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance); diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index bcfd845d7fcb0..2f3c36bb3468e 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -338,6 +338,10 @@ def to_list(s): from .libpaddle import _cleanup_mmap_fds from .libpaddle import _remove_tensor_list_mmap_fds from .libpaddle import _set_max_memory_map_allocation_pool_size + + # CINN + from .libpaddle import is_run_with_cinn + except Exception as e: if has_paddle_dy_lib: sys.stderr.write( diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py index 0ef5186dab2d0..d52364462791d 100644 --- a/test/dygraph_to_static/test_cinn.py +++ b/test/dygraph_to_static/test_cinn.py @@ -64,6 +64,18 @@ def train(self, use_cinn): sgd.clear_grad() res.append(out.numpy()) + + if use_cinn and paddle.device.is_compiled_with_cinn(): + self.assertTrue( + paddle.framework.core.is_run_with_cinn(), + msg="The test was not running with CINN! Please check.", + ) + else: + self.assertFalse( + paddle.framework.core.is_run_with_cinn(), + msg="The test should not running with CINN when the whl package was not compiled with CINN! Please check.", + ) + return res def test_cinn(self):