From 87f6013c45b6044a983971b5e2f044155a10913b Mon Sep 17 00:00:00 2001 From: 6clc Date: Thu, 15 Jun 2023 14:52:32 +0800 Subject: [PATCH 01/14] feat(cmake): add cmake of cinn --- CMakeLists.txt | 41 ++- cmake/cinn.cmake | 299 ++++++++++++++++++++ cmake/cinn/config.cmake | 11 + cmake/cinn/core.cmake | 459 +++++++++++++++++++++++++++++++ cmake/cinn/export.map | 6 + cmake/cinn/external/absl.cmake | 78 ++++++ cmake/cinn/external/boost.cmake | 65 +++++ cmake/cinn/external/ginac.cmake | 36 +++ cmake/cinn/external/isl.cmake | 32 +++ cmake/cinn/external/jitify.cmake | 28 ++ cmake/cinn/external/llvm.cmake | 129 +++++++++ cmake/cinn/external/openmp.cmake | 37 +++ cmake/cinn/llvm.cmake | 86 ++++++ cmake/cinn/nvrtc.cmake | 24 ++ cmake/cinn/nvtx.cmake | 53 ++++ cmake/cinn/system.cmake | 106 +++++++ cmake/cinn/version.cmake | 76 +++++ cmake/external/cinn.cmake | 96 ------- cmake/external/pybind11.cmake | 2 +- cmake/third_party.cmake | 37 ++- python/CMakeLists.txt | 40 +++ python/setup_cinn.py.in | 181 ++++++++++++ 22 files changed, 1803 insertions(+), 119 deletions(-) create mode 100644 cmake/cinn.cmake create mode 100755 cmake/cinn/config.cmake create mode 100644 cmake/cinn/core.cmake create mode 100644 cmake/cinn/export.map create mode 100644 cmake/cinn/external/absl.cmake create mode 100644 cmake/cinn/external/boost.cmake create mode 100644 cmake/cinn/external/ginac.cmake create mode 100644 cmake/cinn/external/isl.cmake create mode 100644 cmake/cinn/external/jitify.cmake create mode 100644 cmake/cinn/external/llvm.cmake create mode 100644 cmake/cinn/external/openmp.cmake create mode 100644 cmake/cinn/llvm.cmake create mode 100644 cmake/cinn/nvrtc.cmake create mode 100644 cmake/cinn/nvtx.cmake create mode 100644 cmake/cinn/system.cmake create mode 100644 cmake/cinn/version.cmake delete mode 100644 cmake/external/cinn.cmake create mode 100644 python/setup_cinn.py.in diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a75d8c35552d..9d354485811f0 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,16 @@ option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME" OFF) option(WITH_CUSPARSELT "Compile PaddlePaddle with CUSPARSELT" OFF) option(WITH_SETUP_INSTALL "Compile PaddlePaddle with setup.py" OFF) option(WITH_SHARED_PHI "Compile PaddlePaddle with SHARED LIB of PHI" OFF) +option(CINN_ONLY "Compile CINN only in Paddle" OFF) + +find_package(Git REQUIRED) + +# config GIT_URL with github mirrors to speed up dependent repos clone +option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL}) +if(NOT GIT_URL) + set(GIT_URL "https://github.com") +endif() + # Note(zhouwei): It use option above, so put here include(init) include(generic) # simplify cmake module @@ -112,7 +122,7 @@ endif() if(WIN32) option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) - message("Build static library of PHI") + set(CMAKE_SUPPRESS_REGENERATION ON) set(CMAKE_STATIC_LIBRARY_PREFIX lib) @@ -229,13 +239,6 @@ else() ) endif() -find_package(Git REQUIRED) - -# config GIT_URL with github mirrors to speed up dependent repos clone -option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL}) -if(NOT GIT_URL) - set(GIT_URL "https://github.com") -endif() find_package(Threads REQUIRED) @@ -569,6 +572,28 @@ include(third_party include(flags) # set paddle compile flags +#------------- cinn cmake config start -------------- + +if(WITH_CINN) + message(STATUS "Compile Paddle with CINN.") + include(cmake/cinn.cmake) + add_definitions(-DPADDLE_WITH_CINN) + if(WITH_GPU) + add_definitions(-DCINN_WITH_CUDA) + 
add_definitions(-DCINN_WITH_CUDNN) + endif() + + if(CINN_ONLY) + if(WITH_PYTHON) + add_subdirectory(python) + endif() + add_subdirectory(test) + return() + endif() +endif() + +#------------- cinn cmake config end -------------- + if(WITH_PROFILER) find_package(Gperftools REQUIRED) include_directories(${GPERFTOOLS_INCLUDE_DIR}) diff --git a/cmake/cinn.cmake b/cmake/cinn.cmake new file mode 100644 index 0000000000000..74fdf7c4ae358 --- /dev/null +++ b/cmake/cinn.cmake @@ -0,0 +1,299 @@ +set(CMAKE_SKIP_INSTALL_ALL_DEPENDENCY TRUE) + +set(CINN_THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(DOWNLOAD_MODEL_DIR "${CINN_THIRD_PARTY_PATH}/model") + +if(NOT DEFINED ENV{runtime_include_dir}) + message( + STATUS + "set runtime_include_dir: ${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/cuda") + set(ENV{runtime_include_dir} "${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/cuda") + add_definitions( + -DRUNTIME_INCLUDE_DIR="${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/cuda") +endif() + +if(WITH_TESTING) + add_definitions(-DCINN_WITH_TEST) +endif() +if(WITH_DEBUG) + add_definitions(-DCINN_WITH_DEBUG) +endif() + + +# TODO(zhhsplendid): CINN has lots of warnings during early development. +# They will be treated as errors under paddle. We set no-error now and we will +# clean the code in the future. +add_definitions(-w) + +include(cmake/cinn/version.cmake) +# include the customized configures +if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake) + include(${CMAKE_BINARY_DIR}/config.cmake) +endif() + +if(WITH_GPU) + message(STATUS "Enable CINN CUDA") + add_definitions(-DCINN_WITH_CUDA) + message(STATUS "Enable CINN CUDNN") + add_definitions(-DCINN_WITH_CUDNN) + enable_language(CUDA) + find_package(CUDA REQUIRED) + include_directories(${CUDA_INCLUDE_DIRS}) + include_directories(${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/cuda) + include_directories(/usr/lib/x86_64-linux-gnu) + set(CUDA_SEPARABLE_COMPILATION ON) + + cuda_select_nvcc_arch_flags(ARCH_FLAGS Auto) + list(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) + + message( + STATUS + "copy paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h to $ENV{runtime_include_dir}" + ) + file(COPY paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h + DESTINATION $ENV{runtime_include_dir}) + + find_library(CUDASTUB libcuda.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/ + REQUIRED) + find_library(CUBLAS libcublas.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + /usr/lib REQUIRED) + find_library(CUDNN libcudnn.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 /usr/lib + REQUIRED) + find_library(CURAND libcurand.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + /usr/lib REQUIRED) + find_library(CUSOLVER libcusolver.so HINTS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + /usr/lib REQUIRED) +endif() + +set(cinnapi_src CACHE INTERNAL "" FORCE) +set(core_src CACHE INTERNAL "" FORCE) +set(core_includes CACHE INTERNAL "" FORCE) +set(core_proto_includes CACHE INTERNAL "" FORCE) + +include_directories(${CMAKE_SOURCE_DIR}) +include_directories(${CMAKE_BINARY_DIR}) + +include(cmake/generic.cmake) +include(cmake/cinn/system.cmake) +include(cmake/cinn/core.cmake) +include(cmake/cinn/external/absl.cmake) +include(cmake/cinn/nvrtc.cmake) +include(cmake/cinn/nvtx.cmake) +include(cmake/cinn/external/llvm.cmake) +include(cmake/cinn/external/isl.cmake) +include(cmake/cinn/external/ginac.cmake) +include(cmake/cinn/external/openmp.cmake) +include(cmake/cinn/external/jitify.cmake) + + +set(LINK_FLAGS + "-Wl,--version-script ${CMAKE_CURRENT_SOURCE_DIR}/cmake/cinn/export.map" + CACHE INTERNAL "") 
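+# Note: cmake/cinn/export.map (shown later in this patch) is a linker version
+# script: it keeps only `RegisterKernels` in the shared library's dynamic
+# symbol table and makes every other symbol local. A minimal sketch of how
+# such a cached flag is typically attached to a target (illustrative only;
+# the target name is hypothetical and the real consumer of LINK_FLAGS lives
+# outside this hunk):
+#   set_target_properties(some_shared_lib PROPERTIES LINK_FLAGS "${LINK_FLAGS}")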
+set(global_test_args + "--cinn_x86_builtin_code_root=${CMAKE_SOURCE_DIR}/paddle/cinn/backends") + +set(Python_VIRTUALENV FIRST) + +if(NOT PYTHON_EXECUTABLE) + find_package(PythonInterp ${PY_VERSION} REQUIRED) +endif() + +if(NOT PYTHON_LIBRARIES) + find_package(PythonLibs ${PY_VERSION} REQUIRED) +endif() + +message(STATUS "PYTHON_LIBRARIES: ${PYTHON_LIBRARIES}") +message(STATUS "PYTHON_INCLUDE_DIR: ${PYTHON_INCLUDE_DIR}") + +include_directories(${PYTHON_INCLUDE_DIR}) + +set(core_deps CACHE INTERNAL "" FORCE) +set(hlir_src CACHE INTERNAL "" FORCE) + +# TODO(chenweihang): The logic later depends adding cinn subdirectory here, +# but better to move to paddle/CMakeLists.txt +add_subdirectory(paddle/cinn) + +set(core_src "${cinnapi_src}") + +cinn_cc_library( + cinnapi + SHARED + SRCS + ${cinnapi_src} + DEPS + glog + ${llvm_libs} + cinn_framework_proto + param_proto + auto_schedule_proto + schedule_desc_proto + absl + isl + ginac + pybind + ${jitify_deps}) +add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) +add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) + +target_link_libraries(cinnapi ${PYTHON_LIBRARIES}) + +if(WITH_MKL) + target_link_libraries(cinnapi cinn_mklml) + add_dependencies(cinnapi cinn_mklml) + if(WITH_MKLDNN) + target_link_libraries(cinnapi mkldnn) + add_dependencies(cinnapi mkldnn) + endif() +endif() + +if(WITH_GPU) + target_link_libraries( + cinnapi + ${CUDA_NVRTC_LIB} + ${CUDA_LIBRARIES} + ${CUDASTUB} + ${CUBLAS} + ${CUDNN} + ${CURAND} + ${CUSOLVER}) + if(NVTX_FOUND) + target_link_libraries(cinnapi ${CUDA_NVTX_LIB}) + endif() +endif() + +function(gen_cinncore LINKTYPE) + set(CINNCORE_TARGET cinncore) + if(${LINKTYPE} STREQUAL "STATIC") + set(CINNCORE_TARGET cinncore_static) + endif() + cinn_cc_library( + ${CINNCORE_TARGET} + ${LINKTYPE} + SRCS + ${core_src} + DEPS + glog + ${llvm_libs} + cinn_framework_proto + param_proto + auto_schedule_proto + schedule_desc_proto + absl + isl + ginac) + add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) + add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) + + add_dependencies(${CINNCORE_TARGET} pybind) + target_link_libraries(${CINNCORE_TARGET} ${PYTHON_LIBRARIES}) + + if(WITH_MKL) + target_link_libraries(${CINNCORE_TARGET} cinn_mklml) + add_dependencies(${CINNCORE_TARGET} cinn_mklml) + if(WITH_MKLDNN) + target_link_libraries(${CINNCORE_TARGET} mkldnn) + add_dependencies(${CINNCORE_TARGET} mkldnn) + endif() + endif() + + if(WITH_GPU) + target_link_libraries( + ${CINNCORE_TARGET} + ${CUDA_NVRTC_LIB} + ${CUDA_LIBRARIES} + ${CUDASTUB} + ${CUBLAS} + ${CUDNN} + ${CURAND} + ${CUSOLVER} + ${jitify_deps}) + if(NVTX_FOUND) + target_link_libraries(${CINNCORE_TARGET} ${CUDA_NVTX_LIB}) + endif() + endif() +endfunction() + +gen_cinncore(STATIC) +gen_cinncore(SHARED) + +# --------distribute cinncore lib and include begin-------- +set(PUBLISH_LIBS ON) +if(PUBLISH_LIBS) + set(core_includes + "${core_includes};paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh") + foreach(header ${core_includes}) + get_filename_component(prefix ${header} DIRECTORY) + file(COPY ${header} + DESTINATION ${CMAKE_BINARY_DIR}/dist/cinn/include/${prefix}) + endforeach() + + foreach(proto_header ${core_proto_includes}) + string(REPLACE ${CMAKE_BINARY_DIR}/ "" relname ${proto_header}) + get_filename_component(prefix ${relname} DIRECTORY) + set(target_name ${CMAKE_BINARY_DIR}/dist/cinn/include/${relname}) + add_custom_command( + TARGET cinnapi + POST_BUILD + COMMENT "copy generated proto header 
'${relname}' to dist" + COMMAND cmake -E copy ${proto_header} ${target_name} DEPENDS cinnapi) + endforeach() + + add_custom_command( + TARGET cinnapi + POST_BUILD + COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinnapi.so + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinnapi.so + COMMAND cmake -E copy_directory ${CINN_THIRD_PARTY_PATH}/install + ${CMAKE_BINARY_DIR}/dist/third_party DEPENDS cinnapi) + add_custom_command( + TARGET cinncore_static + POST_BUILD + COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinncore_static.a + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinncore_static.a + COMMAND + cmake -E copy + ${CMAKE_BINARY_DIR}/paddle/cinn/frontend/paddle/libcinn_framework_proto.a + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinn_framework_proto.a + COMMAND + cmake -E copy ${CMAKE_BINARY_DIR}/paddle/cinn/hlir/pe/libparam_proto.a + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libparam_proto.a + COMMAND + cmake -E copy + ${CMAKE_BINARY_DIR}/paddle/cinn/auto_schedule/libauto_schedule_proto.a + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libauto_schedule_proto.a + COMMAND + cmake -E copy ${CMAKE_BINARY_DIR}/paddle/cinn/ir/libschedule_desc_proto.a + ${CMAKE_BINARY_DIR}/dist/cinn/lib/libschedule_desc_proto.a + COMMENT "distribute libcinncore_static.a and related header files." DEPENDS + cinncore_static) +endif() +# --------distribute cinncore lib and include end-------- + +set(CINN_LIB_NAME "libcinnapi.so") +set(CINN_LIB_LOCATION "${CMAKE_BINARY_DIR}/dist/cinn/lib") +set(CINN_LIB "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}") + +###################################### +# Add CINN's dependencies header files +###################################### + +# Add absl +set(ABSL_INCLUDE_DIR "${CMAKE_BINARY_DIR}/dist/third_party/absl/include") +include_directories(${ABSL_INCLUDE_DIR}) + +# Add isl +set(ISL_INCLUDE_DIR "${CMAKE_BINARY_DIR}/dist/third_party/isl/include") +include_directories(${ISL_INCLUDE_DIR}) + +# Add LLVM +set(LLVM_INCLUDE_DIR "${CMAKE_BINARY_DIR}/dist/third_party/llvm/include") +include_directories(${LLVM_INCLUDE_DIR}) + +###################################################### +# Put external_cinn and dependencies together as a lib +###################################################### + +set(CINN_INCLUDE_DIR "${CMAKE_BINARY_DIR}/dist/cinn/include") +include_directories(${CINN_INCLUDE_DIR}) diff --git a/cmake/cinn/config.cmake b/cmake/cinn/config.cmake new file mode 100755 index 0000000000000..4a390539fabef --- /dev/null +++ b/cmake/cinn/config.cmake @@ -0,0 +1,11 @@ +# The home path of ISL +# Required! +set(ISL_HOME "") + +# Whether enable NVidia CUDA support. 
+# Possible values: ON, OFF +set(WITH_GPU ON) + +set(WITH_MKL ON) +set(WITH_MKLDNN ON) +set(USE_OPENMP "intel") diff --git a/cmake/cinn/core.cmake b/cmake/cinn/core.cmake new file mode 100644 index 0000000000000..91809b697aeec --- /dev/null +++ b/cmake/cinn/core.cmake @@ -0,0 +1,459 @@ +set(CMAKE_CXX_FLAGS + "${CMAKE_CXX_FLAGS} -fPIC -mavx -mfma -Wno-write-strings -Wno-psabi") + +set(PADDLE_RESOURCE_URL + "http://paddle-inference-dist.bj.bcebos.com" + CACHE STRING "inference download url") + +function(cinn_cc_library TARGET_NAME) + set(options STATIC static SHARED shared) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS) + cmake_parse_arguments(cinn_cc_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + if(cinn_cc_library_SRCS) + if(cinn_cc_library_SHARED OR cinn_cc_library_shared) # build *.so + add_library(${TARGET_NAME} SHARED ${cinn_cc_library_SRCS}) + else() + add_library(${TARGET_NAME} STATIC ${cinn_cc_library_SRCS}) + endif() + + if(cinn_cc_library_DEPS) + # Don't need link libwarpctc.so + target_link_libraries(${TARGET_NAME} ${cinn_cc_library_DEPS}) + add_dependencies(${TARGET_NAME} ${cinn_cc_library_DEPS}) + endif() + + # cpplint code style + foreach(source_file ${cinn_cc_library_SRCS}) + string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND cinn_cc_library_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + endif() + endforeach() + else(cinn_cc_library_SRCS) + if(cinn_cc_library_DEPS) + cinn_merge_static_libs(${TARGET_NAME} ${cinn_cc_library_DEPS}) + else() + message( + FATAL_ERROR + "Please specify source files or libraries in cinn_cc_library(${TARGET_NAME} ...)." + ) + endif() + endif(cinn_cc_library_SRCS) + + if((NOT ("${TARGET_NAME}" STREQUAL "cinn_gtest_main")) + AND (NOT ("${TARGET_NAME}" STREQUAL "utils")) + AND (NOT ("${TARGET_NAME}" STREQUAL "lib"))) + target_link_libraries(${TARGET_NAME} Threads::Threads) + + endif( + (NOT ("${TARGET_NAME}" STREQUAL "cinn_gtest_main")) + AND (NOT ("${TARGET_NAME}" STREQUAL "utils")) + AND (NOT ("${TARGET_NAME}" STREQUAL "lib"))) +endfunction(cinn_cc_library) + +list(APPEND CMAKE_CTEST_ARGUMENTS) + +function(remove_gflags TARGET_NAME) + get_target_property(TARGET_LIBRARIES ${TARGET_NAME} LINK_LIBRARIES) + list(REMOVE_ITEM TARGET_LIBRARIES glog) + list(REMOVE_ITEM TARGET_LIBRARIES gflags) + set_property(TARGET ${TARGET_NAME} PROPERTY LINK_LIBRARIES + ${TARGET_LIBRARIES}) +endfunction() + +function(cinn_cc_test TARGET_NAME) + if(WITH_TESTING) + set(options SERIAL) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS ARGS) + cmake_parse_arguments(cinn_cc_test "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + add_executable(${TARGET_NAME} ${cinn_cc_test_SRCS}) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${os_dependency_modules} + cinn_gtest_main gtest glog ${cinn_cc_test_DEPS}) + add_dependencies(${TARGET_NAME} cinn_gtest_main gtest glog + ${cinn_cc_test_DEPS}) + + add_test( + NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} "${cinn_cc_test_ARGS}" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + if(${cinn_cc_test_SERIAL}) + set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + endif() + # No unit test should exceed 10 minutes. 
+    set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 6000)
+    remove_gflags(${TARGET_NAME})
+  endif()
+endfunction()
+
+function(cinn_nv_library TARGET_NAME)
+  if(WITH_GPU)
+    set(options STATIC static SHARED shared)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(cinn_nv_library "${options}" "${oneValueArgs}"
+                          "${multiValueArgs}" ${ARGN})
+    if(cinn_nv_library_SRCS)
+      if(cinn_nv_library_SHARED OR cinn_nv_library_shared) # build *.so
+        cuda_add_library(${TARGET_NAME} SHARED ${cinn_nv_library_SRCS})
+      else()
+        cuda_add_library(${TARGET_NAME} STATIC ${cinn_nv_library_SRCS})
+      endif()
+      if(cinn_nv_library_DEPS)
+        add_dependencies(${TARGET_NAME} ${cinn_nv_library_DEPS})
+        target_link_libraries(${TARGET_NAME} ${cinn_nv_library_DEPS})
+      endif()
+      # cpplint code style
+      foreach(source_file ${cinn_nv_library_SRCS})
+        string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
+        if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+          list(APPEND cinn_nv_library_HEADERS
+               ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+        endif()
+      endforeach()
+    else(cinn_nv_library_SRCS)
+      if(cinn_nv_library_DEPS)
+        cinn_merge_static_libs(${TARGET_NAME} ${cinn_nv_library_DEPS})
+      else()
+        message(
+          FATAL_ERROR
+            "Please specify source files or libraries in cinn_nv_library.")
+      endif()
+    endif(cinn_nv_library_SRCS)
+    target_link_libraries(${TARGET_NAME} Threads::Threads)
+  endif()
+endfunction(cinn_nv_library)
+
+function(cinn_nv_binary TARGET_NAME)
+  if(WITH_GPU)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(cinn_nv_binary "${options}" "${oneValueArgs}"
+                          "${multiValueArgs}" ${ARGN})
+    cuda_add_executable(${TARGET_NAME} ${cinn_nv_binary_SRCS})
+    if(cinn_nv_binary_DEPS)
+      target_link_libraries(${TARGET_NAME} ${cinn_nv_binary_DEPS})
+      add_dependencies(${TARGET_NAME} ${cinn_nv_binary_DEPS})
+      common_link(${TARGET_NAME})
+    endif()
+  endif()
+endfunction(cinn_nv_binary)
+
+function(cinn_nv_test TARGET_NAME)
+  if(WITH_GPU AND WITH_TESTING)
+    set(options SERIAL)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS ARGS)
+    cmake_parse_arguments(cinn_nv_test "${options}" "${oneValueArgs}"
+                          "${multiValueArgs}" ${ARGN})
+    cuda_add_executable(${TARGET_NAME} ${cinn_nv_test_SRCS})
+    get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
+    target_link_libraries(
+      ${TARGET_NAME}
+      ${cinn_nv_test_DEPS}
+      cinn_gtest_main
+      gtest
+      ${os_dependency_modules}
+      ${CUDNN_LIBRARY}
+      ${CUBLAS_LIBRARIES}
+      ${CUDA_LIBRARIES})
+    add_dependencies(${TARGET_NAME} ${cinn_nv_test_DEPS} cinn_gtest_main gtest)
+    common_link(${TARGET_NAME})
+    add_test(
+      NAME ${TARGET_NAME}
+      COMMAND ${TARGET_NAME} "${cinn_nv_test_ARGS}"
+      WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+    if(cinn_nv_test_SERIAL)
+      set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
+    endif()
+    target_link_libraries(
+      ${TARGET_NAME} Threads::Threads ${CUDA_NVRTC_LIB} ${CUDA_LIBRARIES}
+      ${CUDA_cudart_static_LIBRARY}
+      ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs/libcuda.so)
+    if(NVTX_FOUND)
+      target_link_libraries(${TARGET_NAME} ${CUDA_NVTX_LIB})
+    endif()
+    remove_gflags(${TARGET_NAME})
+  endif()
+endfunction(cinn_nv_test)
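+
+# A hypothetical usage sketch of the helpers above (target, source, and
+# dependency names are illustrative, not taken from this patch):
+#   cinn_cc_library(my_ir_util SRCS my_ir_util.cc DEPS absl isl)
+#   cinn_cc_test(my_ir_util_test SRCS my_ir_util_test.cc DEPS my_ir_util)
+#   cinn_nv_test(my_kernel_test SRCS my_kernel_test.cu DEPS my_ir_util SERIAL)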
+
+# Add a dependency so that TARGET depends on the test result of DEP; this
+# function executes DEP during make.
+function(add_run_test_dependency TARGET_NAME DEP_NAME)
+  if(WITH_TESTING)
+    set(custom_target_name ${TARGET_NAME}_TEST_OUTPUT_DEPENDENCY_ON_${DEP_NAME})
+    add_custom_target(
+      ${custom_target_name}
+      COMMAND
+        cd ${CMAKE_CURRENT_BINARY_DIR} && ./${DEP_NAME}
+        --cinn_x86_builtin_code_root=${CMAKE_SOURCE_DIR}/paddle/cinn/backends
+      COMMAND cd ${CMAKE_BINARY_DIR}
+      DEPENDS ${DEP_NAME})
+    add_dependencies(${TARGET_NAME} ${DEP_NAME} ${custom_target_name})
+  endif(WITH_TESTING)
+endfunction(add_run_test_dependency)
+
+# Find all third_party modules used by the paddle static library, to reduce
+# the dependencies when building the inference libs.
+set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY)
+function(find_fluid_thirdparties TARGET_NAME)
+  get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE)
+  string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path
+                       ${__target_path})
+  string(FIND "${__target_path}" "third_party" pos)
+  if(pos GREATER 1)
+    get_property(fluid_third_partys GLOBAL PROPERTY FLUID_THIRD_PARTY)
+    set(fluid_third_partys ${fluid_third_partys} ${TARGET_NAME})
+    set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY "${fluid_third_partys}")
+  endif()
+endfunction(find_fluid_thirdparties)
+
+function(cinn_merge_static_libs TARGET_NAME)
+  set(libs ${ARGN})
+  list(REMOVE_DUPLICATES libs)
+
+  # Get all propagation dependencies from the merged libraries
+  foreach(lib ${libs})
+    list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
+  endforeach()
+  if(libs_deps)
+    list(REMOVE_DUPLICATES libs_deps)
+  endif()
+
+  # To produce a library we need at least one source file. It is created by
+  # the add_custom_command below and also helps to track dependencies.
+  set(target_SRCS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
+
+  if(APPLE) # Use OSX's libtool to merge archives
+    # Make the generated dummy source file depend on all static input libs.
+    # If an input lib changes, the source file is touched, which causes the
+    # desired effect (relink).
+    add_custom_command(
+      OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS}
+         "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    foreach(lib ${libs})
+      # Get the file names of the libraries to be merged
+      set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
+    endforeach()
+    add_custom_command(
+      TARGET ${TARGET_NAME}
+      POST_BUILD
+      COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
+      COMMAND /usr/bin/libtool -static -o
+              "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
+  endif(APPLE)
+  # general UNIX: use "ar" to extract objects and re-add to a common lib
+  if(LINUX)
+    set(target_DIR ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}.dir)
+
+    foreach(lib ${libs})
+      # list of objects in the input library
+      set(objlistfile ${target_DIR}/${lib}.objlist)
+      set(objdir ${target_DIR}/${lib}.objdir)
+
+      add_custom_command(
+        OUTPUT ${objdir}
+        COMMAND ${CMAKE_COMMAND} -E make_directory ${objdir}
+        DEPENDS ${lib})
+
+      add_custom_command(
+        OUTPUT ${objlistfile}
+        COMMAND ${CMAKE_AR} -x "$<TARGET_FILE:${lib}>"
+        COMMAND ${CMAKE_AR} -t "$<TARGET_FILE:${lib}>" > ${objlistfile}
+        DEPENDS ${lib} ${objdir}
+        WORKING_DIRECTORY ${objdir})
+
+      list(APPEND target_OBJS "${objlistfile}")
+    endforeach()
+
+    # Make the generated dummy source file depend on all static input libs.
+    # If an input lib changes, the source file is touched, which causes the
+    # desired effect (relink).
+    add_custom_command(
+      OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs} ${target_OBJS})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS}
+         "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    # Get the file name of the generated library
+    set(target_LIBNAME "$<TARGET_FILE:${TARGET_NAME}>")
+
+    add_custom_command(
+      TARGET ${TARGET_NAME}
+      POST_BUILD
+      COMMAND ${CMAKE_AR} crs ${target_LIBNAME} `find ${target_DIR} -name '*.o'`
+      COMMAND ${CMAKE_RANLIB} ${target_LIBNAME}
+      WORKING_DIRECTORY ${target_DIR})
+  endif(LINUX)
+  if(WIN32)
+    # Windows does not support gcc/nvcc combined compiling, so use msvc's
+    # lib.exe to merge the libs. Make the generated dummy source file depend
+    # on all static input libs. If an input lib changes, the source file is
+    # touched, which causes the desired effect (relink).
+    add_custom_command(
+      OUTPUT ${target_SRCS}
+      COMMAND ${CMAKE_COMMAND} -E touch ${target_SRCS}
+      DEPENDS ${libs})
+
+    # Generate dummy static lib
+    file(WRITE ${target_SRCS}
+         "const char *dummy_${TARGET_NAME} = \"${target_SRCS}\";")
+    add_library(${TARGET_NAME} STATIC ${target_SRCS})
+    target_link_libraries(${TARGET_NAME} ${libs_deps})
+
+    foreach(lib ${libs})
+      # Get the file names of the libraries to be merged
+      set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
+    endforeach()
+    # msvc will put the library in the "/Release/xxxlib" directory by default
+    # COMMAND cmake -E remove "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/${TARGET_NAME}.lib"
+    add_custom_command(
+      TARGET ${TARGET_NAME}
+      POST_BUILD
+      COMMAND cmake -E make_directory
+              "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}"
+      COMMAND
+        lib
+        /OUT:${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/lib${TARGET_NAME}.lib
+        ${libfiles})
+  endif(WIN32)
+endfunction(cinn_merge_static_libs)
+
+# Modification of the standard 'protobuf_generate_cpp()' with protobuf-lite
+# support. Usage:
+#   paddle_protobuf_generate_cpp(<srcs_var> <hdrs_var> <proto files ...>)
+function(paddle_protobuf_generate_cpp SRCS HDRS)
+  if(NOT ARGN)
+    message(
+      SEND_ERROR
+        "Error: paddle_protobuf_generate_cpp() called without any proto files")
+    return()
+  endif()
+
+  set(${SRCS})
+  set(${HDRS})
+
+  foreach(FIL ${ARGN})
+    get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
+    get_filename_component(FIL_WE ${FIL} NAME_WE)
+
+    set(_protobuf_protoc_src "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc")
+    set(_protobuf_protoc_hdr "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h")
+    list(APPEND ${SRCS} "${_protobuf_protoc_src}")
+    list(APPEND ${HDRS} "${_protobuf_protoc_hdr}")
+
+    add_custom_command(
+      OUTPUT "${_protobuf_protoc_src}" "${_protobuf_protoc_hdr}"
+      COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}"
+      COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} -I${CMAKE_SOURCE_DIR} --cpp_out
+              "${CMAKE_BINARY_DIR}" ${ABS_FIL}
+      DEPENDS ${ABS_FIL} protoc
+      COMMENT "Running C++ protocol buffer compiler on ${FIL}"
+      VERBATIM)
+  endforeach()
+
+  set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
+  set(${SRCS}
+      ${${SRCS}}
+      PARENT_SCOPE)
+  set(${HDRS}
+      ${${HDRS}}
+      PARENT_SCOPE)
+endfunction()
+
+function(cinn_proto_library TARGET_NAME)
+  set(oneValueArgs "")
+  set(multiValueArgs SRCS DEPS)
+  cmake_parse_arguments(cinn_proto_library "${options}" "${oneValueArgs}"
+                        "${multiValueArgs}" ${ARGN})
+  set(proto_srcs)
+  set(proto_hdrs)
+  paddle_protobuf_generate_cpp(proto_srcs proto_hdrs ${cinn_proto_library_SRCS})
+  cinn_cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS
${cinn_proto_library_DEPS} protobuf) + set("${TARGET_NAME}_HDRS" + ${proto_hdrs} + PARENT_SCOPE) + set("${TARGET_NAME}_SRCS" + ${proto_srcs} + PARENT_SCOPE) +endfunction() + +function(common_link TARGET_NAME) + if(WITH_PROFILER) + target_link_libraries(${TARGET_NAME} gperftools::profiler) + endif() + + if(WITH_JEMALLOC) + target_link_libraries(${TARGET_NAME} jemalloc::jemalloc) + endif() +endfunction() + +# This method is borrowed from Paddle-Lite. +function(download_and_uncompress INSTALL_DIR URL FILENAME) + message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") + string(REGEX REPLACE "[-%.]" "_" FILENAME_EX ${FILENAME}) + set(EXTERNAL_PROJECT_NAME "extern_lite_download_${FILENAME_EX}") + set(UNPACK_DIR "${INSTALL_DIR}/src/${EXTERNAL_PROJECT_NAME}") + ExternalProject_Add( + ${EXTERNAL_PROJECT_NAME} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${INSTALL_DIR} + DOWNLOAD_COMMAND + wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} + ${URL}/${FILENAME} && ${CMAKE_COMMAND} -E tar xzf + ${INSTALL_DIR}/${FILENAME} + DOWNLOAD_DIR ${INSTALL_DIR} + DOWNLOAD_NO_PROGRESS 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND "") +endfunction() + +function(gather_srcs SRC_GROUP) + set(options) + set(oneValueArgs) + set(multiValueArgs "SRCS") + cmake_parse_arguments(prefix "" "" "${multiValueArgs}" ${ARGN}) + foreach(cpp ${prefix_SRCS}) + set(${SRC_GROUP} + "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" + CACHE INTERNAL "") + endforeach() +endfunction() + +function(core_gather_headers) + file( + GLOB includes + LIST_DIRECTORIES false + RELATIVE ${CMAKE_SOURCE_DIR} + *.h) + + foreach(header ${includes}) + set(core_includes + "${core_includes};${header}" + CACHE INTERNAL "") + endforeach() +endfunction() diff --git a/cmake/cinn/export.map b/cmake/cinn/export.map new file mode 100644 index 0000000000000..0b1aff5de9c00 --- /dev/null +++ b/cmake/cinn/export.map @@ -0,0 +1,6 @@ +{ + global: + RegisterKernels; + local: + *; +}; diff --git a/cmake/cinn/external/absl.cmake b/cmake/cinn/external/absl.cmake new file mode 100644 index 0000000000000..93c70c54959d4 --- /dev/null +++ b/cmake/cinn/external/absl.cmake @@ -0,0 +1,78 @@ +include(ExternalProject) + +set(ABSL_SOURCES_DIR ${CINN_THIRD_PARTY_PATH}/absl) +set(ABSL_INSTALL_DIR ${CINN_THIRD_PARTY_PATH}/install/absl) + +set(ABSL_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + +set(ABSL_REPOSITORY "https://github.com/abseil/abseil-cpp.git") +set(ABSL_TAG "20210324.2") + +set(OPTIONAL_ARGS + "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" + "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" + "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" + "-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}" + "-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}" + "-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}" + "-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}") + +ExternalProject_Add( + external_absl + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS gflags + GIT_REPOSITORY ${ABSL_REPOSITORY} + GIT_TAG ${ABSL_TAG} + PREFIX ${ABSL_SOURCES_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS ${OPTIONAL_ARGS} + -DCMAKE_INSTALL_PREFIX=${ABSL_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR=${ABSL_INSTALL_DIR}/lib + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DWITH_GFLAGS=ON + -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags + -DBUILD_TESTING=OFF + -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS + -DCMAKE_INSTALL_PREFIX:PATH=${ABSL_INSTALL_DIR} + -DCMAKE_INSTALL_LIBDIR:PATH=${ABSL_INSTALL_DIR}/lib + 
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_base.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_hash.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_wyhash.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_city.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_strings.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_throw_delegate.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_bad_any_cast_impl.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_bad_optional_access.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_bad_variant_access.a + BUILD_BYPRODUCTS ${ABSL_INSTALL_DIR}/lib/libabsl_raw_hash_set.a) + +# It may be more convinent if we just include all absl libs +set(ABSL_LIB_NAMES + hash + wyhash + city + strings + throw_delegate + bad_any_cast_impl + bad_optional_access + bad_variant_access + raw_hash_set) +set(ABSL_LIBS "") + +add_library(absl STATIC IMPORTED GLOBAL) +set_property(TARGET absl PROPERTY IMPORTED_LOCATION + ${ABSL_INSTALL_DIR}/lib/libabsl_base.a) + +if(NOT USE_PREBUILD_EXTERNAL) + add_dependencies(absl external_absl) +endif() +foreach(lib_name ${ABSL_LIB_NAMES}) + target_link_libraries(absl + INTERFACE ${ABSL_INSTALL_DIR}/lib/libabsl_${lib_name}.a) +endforeach() +include_directories(${ABSL_INSTALL_DIR}/include) diff --git a/cmake/cinn/external/boost.cmake b/cmake/cinn/external/boost.cmake new file mode 100644 index 0000000000000..773b2f89f1704 --- /dev/null +++ b/cmake/cinn/external/boost.cmake @@ -0,0 +1,65 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include(ExternalProject) + +set(BOOST_PROJECT "extern_boost") +# To release PaddlePaddle as a pip package, we have to follow the +# manylinux1 standard, which features as old Linux kernels and +# compilers as possible and recommends CentOS 5. Indeed, the earliest +# CentOS version that works with NVIDIA CUDA is CentOS 6. And a new +# version of boost, say, 1.66.0, doesn't build on CentOS 6. We +# checked that the devtools package of CentOS 6 installs boost 1.41.0. +# So we use 1.41.0 here. +set(BOOST_VER "1.41.0") +set(BOOST_TAR + "boost_1_41_0" + CACHE STRING "" FORCE) +set(BOOST_URL + "http://paddlepaddledeps.bj.bcebos.com/${BOOST_TAR}.tar.gz" + CACHE STRING "" FORCE) + +message(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") + +set(BOOST_SOURCES_DIR ${CINN_THIRD_PARTY_PATH}/boost) +set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}") + +set(BOOST_INCLUDE_DIR + "${BOOST_DOWNLOAD_DIR}" + CACHE PATH "boost include directory." 
FORCE) +set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1) +include_directories(${BOOST_INCLUDE_DIR}) + +ExternalProject_Add( + ${BOOST_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + DOWNLOAD_DIR ${BOOST_DOWNLOAD_DIR} + URL ${BOOST_URL} + DOWNLOAD_NO_PROGRESS 1 + PREFIX ${BOOST_SOURCES_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + UPDATE_COMMAND "") + +if(${CMAKE_VERSION} VERSION_LESS "3.3.0" OR NOT WIN32) + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/boost_dummy.c) + file(WRITE ${dummyfile} "const char *dummy = \"${dummyfile}\";") + add_library(boost STATIC ${dummyfile}) +else() + add_library(boost INTERFACE) +endif() + +add_dependencies(boost ${BOOST_PROJECT}) +set(Boost_INCLUDE_DIR ${BOOST_INCLUDE_DIR}) diff --git a/cmake/cinn/external/ginac.cmake b/cmake/cinn/external/ginac.cmake new file mode 100644 index 0000000000000..5c31ac32fd790 --- /dev/null +++ b/cmake/cinn/external/ginac.cmake @@ -0,0 +1,36 @@ +include(ExternalProject) + +# gmp-6.2.1 https://gmplib.org/download/gmp/gmp-6.2.1.tar.xz +# cln-1.3.6 https://www.ginac.de/CLN/cln-1.3.6.tar.bz2 +# ginac-1.8.1 https://www.ginac.de/ginac-1.8.1.tar.bz2 +# all build with CFLAGS="-fPIC -DPIC" CXXFLAGS="-fPIC -DPIC" --enable-static=yes + +set(GINAC_DOWNLOAD_URL + https://paddle-inference-dist.bj.bcebos.com/CINN/ginac-1.8.1_cln-1.3.6_gmp-6.2.1.tar.gz +) +set(GINAC_MD5 ebc3e4b7770dd604777ac3f01bfc8b06) + +ExternalProject_Add( + external_ginac + ${EXTERNAL_PROJECT_LOG_ARGS} + URL ${GINAC_DOWNLOAD_URL} + URL_MD5 ${GINAC_MD5} + PREFIX ${CINN_THIRD_PARTY_PATH}/ginac + SOURCE_DIR ${CINN_THIRD_PARTY_PATH}/install/ginac + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libginac.a + BUILD_BYPRODUCTS ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libcln.a + BUILD_BYPRODUCTS ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libgmp.a) + +add_library(ginac STATIC IMPORTED GLOBAL) +add_dependencies(ginac external_ginac) +set_property( + TARGET ginac PROPERTY IMPORTED_LOCATION + ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libginac.a) +target_link_libraries( + ginac INTERFACE ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libcln.a + ${CINN_THIRD_PARTY_PATH}/install/ginac/lib/libgmp.a) +include_directories(${CINN_THIRD_PARTY_PATH}/install/ginac/include) diff --git a/cmake/cinn/external/isl.cmake b/cmake/cinn/external/isl.cmake new file mode 100644 index 0000000000000..a78dee350a5ad --- /dev/null +++ b/cmake/cinn/external/isl.cmake @@ -0,0 +1,32 @@ +include(ExternalProject) + +# isl https://github.com/inducer/ISL +# commit-id 6a1760fe46967cda2a06387793a6b7d4a0876581 +# depends on llvm f9dc2b7079350d0fed3bb3775f496b90483c9e42 +# depends on gmp-6.2.1 +# static build +# CPPFLAGS="-fPIC -DPIC" ./configure --with-gmp-prefix= --with-clang-prefix= --enable-shared=no --enable-static=yes + +set(ISL_DOWNLOAD_URL + https://paddle-inference-dist.bj.bcebos.com/CINN/isl-6a1760fe.tar.gz) +set(ISL_MD5 fff10083fb79d394b8a7b7b2089f6183) + +ExternalProject_Add( + external_isl + ${EXTERNAL_PROJECT_LOG_ARGS} + URL ${ISL_DOWNLOAD_URL} + URL_MD5 ${ISL_MD5} + PREFIX ${CINN_THIRD_PARTY_PATH}/isl + SOURCE_DIR ${CINN_THIRD_PARTY_PATH}/install/isl + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + UPDATE_COMMAND "" + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${CINN_THIRD_PARTY_PATH}/install/isl/lib/libisl.a) + +add_library(isl STATIC IMPORTED GLOBAL) +set_property( + TARGET isl PROPERTY IMPORTED_LOCATION + ${CINN_THIRD_PARTY_PATH}/install/isl/lib/libisl.a) +add_dependencies(isl external_isl) 
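+# isl follows the same prebuilt-archive pattern as ginac above: the
+# ExternalProject only downloads and unpacks a pre-compiled tarball (the
+# configure, build, and install steps are all empty), and an IMPORTED STATIC
+# target points at the shipped libisl.a. Downstream code then links it like
+# any normal target; a hedged sketch with a hypothetical consumer:
+#   target_link_libraries(my_polyhedral_pass isl)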
+include_directories(${CINN_THIRD_PARTY_PATH}/install/isl/include)
diff --git a/cmake/cinn/external/jitify.cmake b/cmake/cinn/external/jitify.cmake
new file mode 100644
index 0000000000000..b04d64b12b8fb
--- /dev/null
+++ b/cmake/cinn/external/jitify.cmake
@@ -0,0 +1,28 @@
+if(NOT WITH_GPU)
+  set(JITIFY_FOUND OFF)
+  return()
+endif()
+
+include(ExternalProject)
+
+set(JITIFY_SOURCE_PATH ${CINN_THIRD_PARTY_PATH}/install/jitify)
+
+ExternalProject_Add(
+  external_jitify
+  ${EXTERNAL_PROJECT_LOG_ARGS}
+  GIT_REPOSITORY "https://github.com/NVIDIA/jitify.git"
+  # Pin jitify to a fixed commit for reproducible builds.
+  GIT_TAG 57de649139c866eb83acacfe50c92ad7c6278776
+  PREFIX ${CINN_THIRD_PARTY_PATH}/jitify
+  SOURCE_DIR ${JITIFY_SOURCE_PATH}
+  CONFIGURE_COMMAND ""
+  PATCH_COMMAND ""
+  BUILD_COMMAND ""
+  UPDATE_COMMAND ""
+  INSTALL_COMMAND "")
+
+include_directories(${JITIFY_SOURCE_PATH})
+
+add_library(extern_jitify INTERFACE)
+add_dependencies(extern_jitify external_jitify)
+set(jitify_deps extern_jitify)
diff --git a/cmake/cinn/external/llvm.cmake b/cmake/cinn/external/llvm.cmake
new file mode 100644
index 0000000000000..29ab0967e3053
--- /dev/null
+++ b/cmake/cinn/external/llvm.cmake
@@ -0,0 +1,129 @@
+include(FetchContent)
+
+# set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
+# set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)
+
+set(LLVM_DOWNLOAD_URL
+    https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11-glibc2.17.tar.gz)
+set(LLVM_MD5 33c7d3cc6d370585381e8d90bd7c2198)
+
+set(FETCHCONTENT_BASE_DIR ${CINN_THIRD_PARTY_PATH}/llvm)
+set(FETCHCONTENT_QUIET OFF)
+FetchContent_Declare(
+  external_llvm
+  URL ${LLVM_DOWNLOAD_URL}
+  URL_MD5 ${LLVM_MD5}
+  PREFIX ${CINN_THIRD_PARTY_PATH}/llvm
+  SOURCE_DIR ${CINN_THIRD_PARTY_PATH}/install/llvm)
+if(NOT LLVM_PATH)
+  FetchContent_GetProperties(external_llvm)
+  if(NOT external_llvm_POPULATED)
+    FetchContent_Populate(external_llvm)
+  endif()
+  set(LLVM_PATH ${CINN_THIRD_PARTY_PATH}/install/llvm)
+  set(LLVM_DIR ${CINN_THIRD_PARTY_PATH}/install/llvm/lib/cmake/llvm)
+  set(MLIR_DIR ${CINN_THIRD_PARTY_PATH}/install/llvm/lib/cmake/mlir)
+else()
+  set(LLVM_DIR ${LLVM_PATH}/lib/cmake/llvm)
+  set(MLIR_DIR ${LLVM_PATH}/lib/cmake/mlir)
+endif()
+
+if(${CMAKE_CXX_COMPILER} STREQUAL "clang++")
+  set(CMAKE_EXE_LINKER_FLAGS
+      "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi")
+endif()
+
+message(STATUS "set LLVM_DIR: ${LLVM_DIR}")
+message(STATUS "set MLIR_DIR: ${MLIR_DIR}")
+find_package(LLVM REQUIRED CONFIG HINTS ${LLVM_DIR})
+find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR})
+find_package(ZLIB REQUIRED)
+
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+include(AddLLVM)
+
+include_directories(${LLVM_INCLUDE_DIRS})
+list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
+list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
+include(AddLLVM)
+include(TableGen)
+include(AddMLIR)
+
+message(STATUS "Found MLIR: ${MLIR_DIR}")
+message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
+message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
+
+# To build with MLIR, LLVM is built from source using the following flags:
+
+#[==[
+cmake -G Ninja ../llvm \
+  -DLLVM_ENABLE_PROJECTS="mlir;clang" \
+  -DLLVM_BUILD_EXAMPLES=OFF \
+  -DLLVM_TARGETS_TO_BUILD="X86" \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DLLVM_ENABLE_ASSERTIONS=ON \
+  -DLLVM_ENABLE_ZLIB=OFF \
+  -DLLVM_ENABLE_RTTI=ON \
+  -DLLVM_ENABLE_TERMINFO=OFF \
+  -DCMAKE_INSTALL_PREFIX=./install
+#]==]
+
+# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42
+# (currently a temporary commit).
+# Update: to build
llvm in manylinux docker with glibc-2.17, and use it in manylinux and ubuntu docker, +# the patch https://gist.github.com/zhiqiu/6e8d969176dce13d98fd15338a16265e is needed. + +add_definitions(${LLVM_DEFINITIONS}) + +llvm_map_components_to_libnames( + llvm_libs + Support + Core + irreader + X86 + executionengine + orcjit + mcjit + all + codegen) + +message(STATUS "LLVM libs: ${llvm_libs}") + +get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS) +message(STATUS "MLIR libs: ${mlir_libs}") +add_definitions(${LLVM_DEFINITIONS}) + +# The minimum needed libraries for MLIR IR parse and transform. +set(MLIR_IR_LIBS + MLIRAnalysis + MLIRStandardOps + MLIRPass + MLIRParser + MLIRDialect + MLIRIR + MLIROptLib) + +# tb_base is the name of a xxx.td file (without the .td suffix) +function(mlir_tablegen_on td_base) + set(options) + set(oneValueArgs DIALECT) + cmake_parse_arguments(mlir_tablegen_on "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + set(LLVM_TARGET_DEFINITIONS ${td_base}.td) + mlir_tablegen(${td_base}.hpp.inc -gen-op-decls) + mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) + if(mlir_tablegen_on_DIALECT) + mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls + -dialect=${mlir_tablegen_on_DIALECT}) + endif() + add_public_tablegen_target(${td_base}_IncGen) + add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) +endfunction() + +function(mlir_add_rewriter td_base) + set(LLVM_TARGET_DEFINITIONS ${td_base}.td) + mlir_tablegen(${td_base}.hpp.inc -gen-rewriters + "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass") + add_public_tablegen_target(${td_base}_IncGen) + add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) +endfunction() diff --git a/cmake/cinn/external/openmp.cmake b/cmake/cinn/external/openmp.cmake new file mode 100644 index 0000000000000..2a0194636d6c2 --- /dev/null +++ b/cmake/cinn/external/openmp.cmake @@ -0,0 +1,37 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
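+# USE_OPENMP selects the OpenMP runtime below: "gnu" uses the toolchain's own
+# OpenMP support (libgomp for gcc), "intel" pairs OpenMP with MKL's iomp
+# runtime, and any other value leaves OpenMP off. The value comes from the
+# customized config; for example, cmake/cinn/config.cmake above sets:
+#   set(USE_OPENMP "intel")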
+ +if(USE_OPENMP STREQUAL "gnu") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + add_definitions(-DCINN_USE_OPENMP) + set(WITH_OPENMP ON) + message(STATUS "Build with OpenMP ${OpenMP_CXX_LIBRARIES}") + message(STATUS "CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS}) + else() + set(WITH_OPENMP OFF) + endif() +elseif(USE_OPENMP STREQUAL "intel") + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + message(STATUS "CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS}) + add_definitions(-DCINN_USE_OPENMP) + set(WITH_OPENMP ON) + message(STATUS "Build with OpenMP " ${MKLML_IOMP_LIB}) + else() + set(WITH_OPENMP OFF) + endif() +endif() diff --git a/cmake/cinn/llvm.cmake b/cmake/cinn/llvm.cmake new file mode 100644 index 0000000000000..4fc274e6983cd --- /dev/null +++ b/cmake/cinn/llvm.cmake @@ -0,0 +1,86 @@ +if(${CMAKE_CXX_COMPILER} STREQUAL "clang++") + set(CMAKE_EXE_LINKER_FLAGS + "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi") +endif() + +message(STATUS "set LLVM_DIR: ${LLVM_DIR}") +message(STATUS "set MLIR_DIR: ${MLIR_DIR}") +find_package(LLVM REQUIRED CONFIG HINTS ${LLVM_DIR}) +find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR}) +find_package(ZLIB REQUIRED) + +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(AddLLVM) + +include_directories(${LLVM_INCLUDE_DIRS}) +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +include(AddLLVM) +include(TableGen) +include(AddMLIR) + +message(STATUS "Found MLIR: ${MLIR_DIR}") +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +# To build with MLIR, the LLVM is build from source code using the following flags: + +#[==[ +cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_BUILD_EXAMPLES=OFF \ + -DLLVM_TARGETS_TO_BUILD="X86" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_ZLIB=OFF \ + -DLLVM_ENABLE_RTTI=ON \ +#]==] +# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) + +add_definitions(${LLVM_DEFINITIONS}) + +llvm_map_components_to_libnames( + llvm_libs + Support + Core + irreader + X86 + executionengine + orcjit + mcjit + all + codegen) + +message(STATUS "LLVM libs: ${llvm_libs}") + +get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS) +message(STATUS "MLIR libs: ${mlir_libs}") +add_definitions(${LLVM_DEFINITIONS}) + +# The minimum needed libraries for MLIR IR parse and transform. 
+set(MLIR_IR_LIBS + MLIRAnalysis + MLIRStandardOps + MLIRPass + MLIRParser + MLIRDialect + MLIRIR + MLIROptLib) + +# tb_base is the name of a xxx.td file (without the .td suffix) +function(mlir_tablegen_on td_base) + set(options) + set(oneValueArgs DIALECT) + cmake_parse_arguments(mlir_tablegen_on "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + set(LLVM_TARGET_DEFINITIONS ${td_base}.td) + mlir_tablegen(${td_base}.hpp.inc -gen-op-decls) + mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) + if(mlir_tablegen_on_DIALECT) + mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls + -dialect=${mlir_tablegen_on_DIALECT}) + endif() + add_public_tablegen_target(${td_base}_IncGen) + add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) +endfunction() diff --git a/cmake/cinn/nvrtc.cmake b/cmake/cinn/nvrtc.cmake new file mode 100644 index 0000000000000..987bebfab0c05 --- /dev/null +++ b/cmake/cinn/nvrtc.cmake @@ -0,0 +1,24 @@ +if(NOT WITH_GPU) + return() +endif() + +find_package(PkgConfig) + +find_library( + CUDA_NVRTC_LIB libnvrtc nvrtc + HINTS "${CUDA_TOOLKIT_ROOT_DIR}/lib64" "${LIBNVRTC_LIBRARY_DIR}" + "${CUDA_TOOLKIT_ROOT_DIR}/lib/x64" /usr/lib64 /usr/local/cuda/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(LibNVRTC DEFAULT_MSG CUDA_NVRTC_LIB) + +message(STATUS "found NVRTC: ${CUDA_NVRTC_LIB}") + +mark_as_advanced(CUDA_NVRTC_LIB) + +if(NOT LIBNVRTC_FOUND) + message( + FATAL_ERROR + "Cuda NVRTC Library not found: Specify the LIBNVRTC_LIBRARY_DIR where libnvrtc is located" + ) +endif() diff --git a/cmake/cinn/nvtx.cmake b/cmake/cinn/nvtx.cmake new file mode 100644 index 0000000000000..d5d2049a68d40 --- /dev/null +++ b/cmake/cinn/nvtx.cmake @@ -0,0 +1,53 @@ +if((NOT WITH_GPU) + OR WIN32 + OR APPLE) + set(NVTX_FOUND OFF) + return() +endif() + +set(NVTX_ROOT + "/usr" + CACHE PATH "NVTX ROOT") +find_path( + NVTX_INCLUDE_DIR nvToolsExt.h + PATHS ${NVTX_ROOT} ${NVTX_ROOT}/include $ENV{NVTX_ROOT} + $ENV{NVTX_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} + NO_DEFAULT_PATH) + +get_filename_component(__libpath_hint ${CUDA_CUDART_LIBRARY} PATH) + +set(TARGET_ARCH "x86_64") +if(NOT ${CMAKE_SYSTEM_PROCESSOR}) + set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR}) +endif() + +list( + APPEND + NVTX_CHECK_LIBRARY_DIRS + ${NVTX_ROOT} + ${NVTX_ROOT}/lib64 + ${NVTX_ROOT}/lib + ${NVTX_ROOT}/lib/${TARGET_ARCH}-linux-gnu + $ENV{NVTX_ROOT} + $ENV{NVTX_ROOT}/lib64 + $ENV{NVTX_ROOT}/lib + ${CUDA_TOOLKIT_ROOT_DIR} + ${CUDA_TOOLKIT_ROOT_DIR}/targets/${TARGET_ARCH}-linux/lib) + +find_library( + CUDA_NVTX_LIB + NAMES libnvToolsExt.so + PATHS ${NVTX_CHECK_LIBRARY_DIRS} ${NVTX_INCLUDE_DIR} ${__libpath_hint} + NO_DEFAULT_PATH + DOC "Path to the NVTX library.") + +if(NVTX_INCLUDE_DIR AND CUDA_NVTX_LIB) + set(NVTX_FOUND ON) +else() + set(NVTX_FOUND OFF) +endif() + +if(NVTX_FOUND) + include_directories(${NVTX_INCLUDE_DIR}) + add_definitions(-DCINN_WITH_NVTX) +endif() diff --git a/cmake/cinn/system.cmake b/cmake/cinn/system.cmake new file mode 100644 index 0000000000000..b7e8a760712fc --- /dev/null +++ b/cmake/cinn/system.cmake @@ -0,0 +1,106 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Detects the OS and sets appropriate variables. +# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is +# building for, but the host processor name like centos is necessary +# in some scenes to distinguish system for customization. +# +# for instance, protobuf libs path is /lib64 +# on CentOS, but /lib on other systems. + +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif() + +if(WIN32) + set(HOST_SYSTEM "win32") +else() + if(APPLE) + set(HOST_SYSTEM "macosx") + exec_program( + sw_vers ARGS + -productVersion + OUTPUT_VARIABLE HOST_SYSTEM_VERSION) + string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}") + if(NOT DEFINED $ENV{MACOSX_DEPLOYMENT_TARGET}) + # Set cache variable - end user may change this during ccmake or cmake-gui configure. + set(CMAKE_OSX_DEPLOYMENT_TARGET + ${MACOS_VERSION} + CACHE + STRING + "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value." + ) + endif() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + else() + + if(EXISTS "/etc/issue") + file(READ "/etc/issue" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + elseif(LINUX_ISSUE MATCHES "Debian") + set(HOST_SYSTEM "debian") + elseif(LINUX_ISSUE MATCHES "Ubuntu") + set(HOST_SYSTEM "ubuntu") + elseif(LINUX_ISSUE MATCHES "Red Hat") + set(HOST_SYSTEM "redhat") + elseif(LINUX_ISSUE MATCHES "Fedora") + set(HOST_SYSTEM "fedora") + endif() + + string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION + "${LINUX_ISSUE}") + endif() + + if(EXISTS "/etc/redhat-release") + file(READ "/etc/redhat-release" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + endif() + endif() + + if(NOT HOST_SYSTEM) + set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME}) + endif() + + endif() +endif() + +# query number of logical cores +cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES) + +mark_as_advanced(HOST_SYSTEM CPU_CORES) + +message( + STATUS + "Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}") +message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores") + +# external dependencies log output +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD + 0 # Wrap download in script to log output + LOG_UPDATE + 1 # Wrap update in script to log output + LOG_CONFIGURE + 1 # Wrap configure in script to log output + LOG_BUILD + 0 # Wrap build in script to log output + LOG_TEST + 1 # Wrap test in script to log output + LOG_INSTALL + 0 # Wrap install in script to log output +) diff --git a/cmake/cinn/version.cmake b/cmake/cinn/version.cmake new file mode 100644 index 0000000000000..6b5534ae9184f --- /dev/null +++ b/cmake/cinn/version.cmake @@ -0,0 +1,76 @@ +# Get the latest git tag. 
+set(CINN_VERSION $ENV{CINN_VERSION}) +set(tmp_version "HEAD") +set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?") +set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+") +while("${CINN_VERSION}" STREQUAL "") + # Check current branch name + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-parse --abbrev-ref ${tmp_version} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR} + OUTPUT_VARIABLE GIT_BRANCH_NAME + RESULT_VARIABLE GIT_BRANCH_RESULT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT ${GIT_BRANCH_RESULT}) + execute_process( + COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 --always + ${tmp_version} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR} + OUTPUT_VARIABLE GIT_TAG_NAME + RESULT_VARIABLE GIT_RESULT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT ${GIT_RESULT}) + # Check if current branch is release branch + if(${GIT_BRANCH_NAME} MATCHES "release/${TAG_VERSION_REGEX}") + # Check the tag is a correct version + if(${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}") + # if no tag was found, set CINN_VERSION to 0.0.0 to represent latest + set(CINN_VERSION "0.0.0") + elseif(${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}") + string(REPLACE "v" "" CINN_VERSION ${GIT_TAG_NAME}) + else() # otherwise, get the previous git tag name. + set(tmp_version "${GIT_TAG_NAME}~1") + endif() + else() + execute_process( + COMMAND ${GIT_EXECUTABLE} describe --exact-match --tags ${tmp_version} + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR} + OUTPUT_VARIABLE GIT_EXACT_TAG_NAME + RESULT_VARIABLE GIT_EXACT_TAG_RESULT + ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT ${GIT_EXACT_TAG_NAME}) + # Check if current branch is tag branch + if(${GIT_EXACT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}") + string(REPLACE "v" "" CINN_VERSION ${GIT_EXACT_TAG_NAME}) + else() + set(CINN_VERSION "0.0.0") + endif() + else() + # otherwise, we always set CINN_VERSION to 0.0.0 to represent latest + set(CINN_VERSION "0.0.0") + endif() + endif() + else() + set(CINN_VERSION "0.0.0") + message(WARNING "Cannot add CINN version from git tag") + endif() + else() + set(CINN_VERSION "0.0.0") + message(WARNING "Cannot add CINN version for wrong git branch result") + endif() +endwhile() + +string(REPLACE "-" "." CINN_VER_LIST ${CINN_VERSION}) +string(REPLACE "." ";" CINN_VER_LIST ${CINN_VER_LIST}) +list(GET CINN_VER_LIST 0 CINN_MAJOR_VER) +list(GET CINN_VER_LIST 1 CINN_MINOR_VER) +list(GET CINN_VER_LIST 2 CINN_PATCH_VER) +math(EXPR CINN_VERSION_INTEGER "${CINN_MAJOR_VER} * 1000000 + + ${CINN_MINOR_VER} * 1000 + ${CINN_PATCH_VER}") + +add_definitions(-DCINN_VERSION=${CINN_VERSION}) +add_definitions(-DCINN_VERSION_INTEGER=${CINN_VERSION_INTEGER}) +message( + STATUS + "CINN version is ${CINN_VERSION} (major: ${CINN_MAJOR_VER}, minor: ${CINN_MINOR_VER}, patch: ${CINN_PATCH_VER})" +) diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake deleted file mode 100644 index 7d494ef516cae..0000000000000 --- a/cmake/external/cinn.cmake +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -if(NOT WITH_CINN) - return() -endif() - -if(NOT CINN_GIT_TAG) - set(CINN_GIT_TAG develop) -endif() - -message(STATUS "CINN version: " ${CINN_GIT_TAG}) - -# TODO(zhhsplendid): CINN has lots of warnings during early development. -# They will be treated as errors under paddle. We set no-error now and we will -# clean the code in the future. -add_definitions(-w) - -###################################### -# Build CINN from Git External Project -###################################### -include(ExternalProject) -set(CINN_PREFIX_DIR ${THIRD_PARTY_PATH}/CINN) -set(CINN_OPTIONAL_ARGS - -DPY_VERSION=${PY_VERSION} - -DWITH_CUDA=${WITH_GPU} - -DWITH_CUDNN=${WITH_GPU} - -DWITH_MKL_CBLAS=${WITH_MKL} - -DWITH_MKLDNN=${WITH_MKL} - -DPUBLISH_LIBS=ON - -DWITH_TESTING=ON - -DPYTHON_EXECUTABLE=${PYTHON_EXECUTABLE} - -DPYTHON_INCLUDE_DIR=${PYTHON_INCLUDE_DIR} - -DPYTHON_LIBRARIES=${PYTHON_LIBRARIES}) -set(CINN_BUILD_COMMAND ${CMAKE_COMMAND} --build . --target cinnapi -j) -set(CINN_BINARY_DIR ${CINN_PREFIX_DIR}/src/external_cinn-build) -set(CINN_LIB_NAME "libcinnapi.so") -set(CINN_LIB_LOCATION "${CINN_BINARY_DIR}/dist/cinn/lib") -set(CINN_LIB "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}") - -ExternalProject_Add( - external_cinn - ${EXTERNAL_PROJECT_LOG_ARGS} - GIT_REPOSITORY "${GIT_URL}/PaddlePaddle/CINN.git" - GIT_TAG ${CINN_GIT_TAG} - PREFIX ${CINN_PREFIX_DIR} - BUILD_COMMAND ${CINN_BUILD_COMMAND} - INSTALL_COMMAND "" - CMAKE_ARGS ${CINN_OPTIONAL_ARGS} - BUILD_BYPRODUCTS ${CINN_LIB}) - -ExternalProject_Get_Property(external_cinn BINARY_DIR) -ExternalProject_Get_Property(external_cinn SOURCE_DIR) -set(CINN_SOURCE_DIR ${SOURCE_DIR}) - -message(STATUS "CINN BINARY_DIR: ${CINN_BINARY_DIR}") -message(STATUS "CINN SOURCE_DIR: ${CINN_SOURCE_DIR}") - -###################################### -# Add CINN's dependencies header files -###################################### - -# Add absl -set(ABSL_INCLUDE_DIR "${CINN_BINARY_DIR}/dist/third_party/absl/include") -include_directories(${ABSL_INCLUDE_DIR}) - -# Add isl -set(ISL_INCLUDE_DIR "${CINN_BINARY_DIR}/dist/third_party/isl/include") -include_directories(${ISL_INCLUDE_DIR}) - -# Add LLVM -set(LLVM_INCLUDE_DIR "${CINN_BINARY_DIR}/dist/third_party/llvm/include") -include_directories(${LLVM_INCLUDE_DIR}) - -###################################################### -# Put external_cinn and dependencies together as a lib -###################################################### - -set(CINN_INCLUDE_DIR "${CINN_BINARY_DIR}/dist/cinn/include") - -add_library(cinn SHARED IMPORTED GLOBAL) -set_target_properties(cinn PROPERTIES IMPORTED_LOCATION - "${CINN_LIB_LOCATION}/${CINN_LIB_NAME}") -include_directories(${CINN_INCLUDE_DIR}) -add_dependencies(cinn external_cinn) diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake index 6ce8290d72f42..1e0838145a63d 100644 --- a/cmake/external/pybind11.cmake +++ b/cmake/external/pybind11.cmake @@ -24,7 +24,7 @@ set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include) include_directories(${PYBIND_INCLUDE_DIR}) set(PYBIND_PATCH_COMMAND "") -if(NOT WIN32) +if(NOT WIN32 AND NOT CINN_ONLY) file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/pybind/cast.h.patch native_dst) # Note: [Why calling some `git` commands before `patch`?] 
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 43f5604f2808c..d33b86944008a 100755
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -260,6 +260,29 @@ if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
 endif()

 ########################### include third_party according to flags ###############################
+
+if(CINN_ONLY)
+  include(external/zlib)
+  include(external/gflags)
+  include(external/glog)
+  include(external/gtest)
+  include(external/protobuf)
+  if(WITH_PYTHON)
+    include(external/pybind11)
+  endif()
+  if(WITH_MKL)
+    include(external/mklml)
+    generate_dummy_static_lib(LIB_NAME "cinn_mklml" GENERATOR "mklml.cmake")
+    target_link_libraries(cinn_mklml ${MKLML_LIB} ${MKLML_IOMP_LIB})
+    add_definitions(-DCINN_WITH_MKL_CBLAS)
+  endif()
+  if(WITH_MKLDNN)
+    include(external/mkldnn)
+    add_definitions(-DCINN_WITH_MKLDNN)
+  endif()
+  return()
+endif()
+
 include(external/zlib) # download, build, install zlib
 include(external/gflags) # download, build, install gflags
 include(external/glog) # download, build, install glog
@@ -474,20 +497,6 @@ if(WITH_LITE)
   include(external/lite)
 endif()

-if(WITH_CINN)
-  message(STATUS "Compile Paddle with CINN.")
-  include(external/cinn)
-  add_definitions(-DPADDLE_WITH_CINN)
-  if(WITH_GPU)
-    add_definitions(-DCINN_WITH_CUDA)
-    add_definitions(-DCINN_WITH_CUDNN)
-  endif()
-  if(WITH_MKL)
-    add_definitions(-DCINN_WITH_MKL_CBLAS)
-    add_definitions(-DCINN_WITH_MKLDNN)
-  endif()
-endif()
-
 if(WITH_CRYPTO)
   include(external/cryptopp) # download, build, install cryptopp
   list(APPEND third_party_deps extern_cryptopp)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 8d9073b398417..cd7dc7e12f2a3 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -1,3 +1,43 @@
+if(CINN_ONLY)
+  file(GLOB_RECURSE CINN_PY_FILES ${PROJECT_SOURCE_DIR}/python/cinn/*.py)
+  set(CINN_PYTHON_DIR ${PROJECT_SOURCE_DIR}/python/cinn)
+  set(CINN_CORE_API ${CMAKE_BINARY_DIR}/python/cinn/core_api.so)
+
+  if(WITH_GPU)
+    set(PACKAGE_NAME "cinn-gpu")
+  else()
+    set(PACKAGE_NAME "cinn")
+  endif()
+  set(SETUP_LOG_FILE "setup.py.log")
+  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup_cinn.py.in
+                 ${CMAKE_CURRENT_BINARY_DIR}/setup_cinn.py)
+
+  if(NOT PYTHON_EXECUTABLE)
+    find_package(PythonInterp ${PY_VERSION} REQUIRED)
+    find_package(PythonLibs ${PY_VERSION} REQUIRED)
+  endif()
+
+  message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}")
+
+  # There may be a link file called core_api.so under the dir ${CINN_PYTHON_DIR} due to the `mac_doc`
+  # function defined in build.sh. So, we need to copy the directory ${CINN_PYTHON_DIR} first and
+  # then core_api.so.
+  add_custom_command(
+    OUTPUT ${CINN_CORE_API}
+    COMMAND cp -rf --remove-destination ${CINN_PYTHON_DIR}
+            ${CMAKE_BINARY_DIR}/python/cinn
+    COMMAND cp --remove-destination
+            ${CMAKE_BINARY_DIR}/paddle/cinn/pybind/core_api.so ${CINN_CORE_API}
+    COMMAND cd ${CMAKE_CURRENT_BINARY_DIR} && ${PYTHON_EXECUTABLE} setup_cinn.py
+            bdist_wheel
+    DEPENDS core_api ${CINN_PY_FILES})
+
+  add_custom_target(COPY_CINN_CORE_API ALL DEPENDS ${CINN_CORE_API}
+                                                   ${CINN_PY_FILES})
+
+  return()
+endif()
+
 file(GLOB UTILS_PY_FILES . ./paddle/legacy/utils/*.py)
 file(GLOB_RECURSE FLUID_PY_FILES ./paddle/fluid/*.py)
 set(PY_FILES paddle/__init__.py ${UTILS_PY_FILES} ${FLUID_PY_FILES})
diff --git a/python/setup_cinn.py.in b/python/setup_cinn.py.in
new file mode 100644
index 0000000000000..fbdaac8625840
--- /dev/null
+++ b/python/setup_cinn.py.in
@@ -0,0 +1,181 @@
+import os
+import re
+import subprocess
+import sys
+import shutil
+import errno
+from contextlib import contextmanager
+from setuptools import setup
+
+def set_rpath(lib, rpath):
+    command = "patchelf --set-rpath '{}' {}".format(rpath, lib)
+    if os.system(command) != 0:
+        raise Exception("patch {} failed, command: {}".format(lib, command))
+
+def git_commit():
+    try:
+        cmd = ['git', 'rev-parse', 'HEAD']
+        git_commit = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+            cwd="${PROJECT_SOURCE_DIR}").communicate()[0].strip()
+    except:
+        git_commit = b'Unknown'
+    git_commit = git_commit.decode()
+    return str(git_commit)
+
+def _get_version_detail(idx):
+    assert idx < 3, "version info consists of %(major)d.%(minor)d.%(patch)d, \
+        so the detail index must be less than 3"
+
+    if re.match('${TAG_VERSION_REGEX}', '${PADDLE_VERSION}'):
+        version_details = '${PADDLE_VERSION}'.split('.')
+
+        if len(version_details) >= 3:
+            return version_details[idx]
+
+    return 0
+
+def get_major():
+    return int(_get_version_detail(0))
+
+def get_minor():
+    return int(_get_version_detail(1))
+
+def get_patch():
+    return str(_get_version_detail(2))
+
+def get_cuda_version():
+    if '${WITH_GPU}' == 'ON':
+        return '${CUDA_VERSION}'
+    else:
+        return 'False'
+
+def get_cudnn_version():
+    if '${WITH_GPU}' == 'ON':
+        temp_cudnn_version = ''
+        if '${CUDNN_MAJOR_VERSION}':
+            temp_cudnn_version += '${CUDNN_MAJOR_VERSION}'
+            if '${CUDNN_MINOR_VERSION}':
+                temp_cudnn_version += '.${CUDNN_MINOR_VERSION}'
+                if '${CUDNN_PATCHLEVEL_VERSION}':
+                    temp_cudnn_version += '.${CUDNN_PATCHLEVEL_VERSION}'
+        return temp_cudnn_version
+    else:
+        return 'False'
+
+def is_tagged():
+    try:
+        cmd = ['git', 'describe', '--exact-match', '--tags', 'HEAD']
+        git_tag = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+            stderr=subprocess.DEVNULL,
+            cwd="${PROJECT_SOURCE_DIR}").communicate()[0].strip()
+        git_tag = git_tag.decode()
+    except:
+        return False
+
+    if str(git_tag).replace('v', '') == '${CINN_VERSION}':
+        return True
+    else:
+        return False
+
+def write_version_py(filename='cinn/version/info.py'):
+    cnt = '''# THIS FILE IS GENERATED FROM CINN SETUP.PY
+#
+full_version = '%(major)d.%(minor)d.%(patch)s'
+major = '%(major)d'
+minor = '%(minor)d'
+patch = '%(patch)s'
+cuda_version = '%(cuda)s'
+cudnn_version = '%(cudnn)s'
+istaged = %(istaged)s
+commit = '%(commit)s'
+with_mkl = '%(with_mkl)s'
+'''
+    commit = git_commit()
+
+    dirname = os.path.dirname(filename)
+
+    try:
+        os.makedirs(dirname)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
+
+    with open(filename, 'w') as f:
+        f.write(cnt % {
+            'major': get_major(),
+            'minor': get_minor(),
+            'patch': get_patch(),
+            'version': '${CINN_VERSION}',
+            'cuda': get_cuda_version(),
+            'cudnn': get_cudnn_version(),
+            'commit': commit,
+            'istaged': is_tagged(),
+            'with_mkl': '${WITH_MKL}'})
+
+write_version_py(filename='${CINN_BINARY_DIR}/python/cinn/version/info.py')
+
+if sys.platform != 'win32':
+    @contextmanager
+    def redirect_stdout():
+        f_log = open('${SETUP_LOG_FILE}', 'w')
+        origin_stdout = sys.stdout
+        sys.stdout = f_log
+        try:
+            yield
+        finally:
+            sys.stdout = origin_stdout
+            f_log.close()
+else:
+    @contextmanager
+    def redirect_stdout():
+        yield
+
+libs_path = '${CMAKE_BINARY_DIR}/python/cinn/libs'
+
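+# The loop below copies each dependency into `libs_path` and, for shared
+# objects, rewrites its rpath so the wheel stays relocatable. A hedged sketch
+# of the equivalent manual step (the library name is purely illustrative):
+#
+#   patchelf --set-rpath '$ORIGIN/' python/cinn/libs/libmklml_intel.so
+#
+# '$ORIGIN/' tells the dynamic linker to search next to the loading object.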
+cinnlibs = []
+package_data = {'cinn': ['core_api.so'], 'cinn.libs': []}
+
+if '${WITH_MKL}' == 'ON':
+    cinnlibs.append('${MKLML_LIB}')
+    cinnlibs.append('${MKLML_IOMP_LIB}')
+
+if '${WITH_GPU}' == 'ON':
+    cinnlibs.append('${CMAKE_BINARY_DIR}/dist/cinn/include/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh')
+    cinnlibs.append('${CMAKE_BINARY_DIR}/dist/cinn/include/paddle/cinn/runtime/cuda/float16.h')
+    cinnlibs.append('${CMAKE_BINARY_DIR}/dist/cinn/include/paddle/cinn/runtime/cuda/bfloat16.h')
+
+for lib in cinnlibs:
+    shutil.copy(lib, libs_path)
+    libname = os.path.basename(lib)
+    if lib.endswith('.so'):
+        set_rpath(os.path.join(libs_path, libname), '$ORIGIN/')
+    package_data['cinn.libs'].append(libname)
+
+set_rpath('${CMAKE_BINARY_DIR}/python/cinn/core_api.so', '$ORIGIN/libs/')
+
+packages = ["cinn",
+            "cinn.auto_schedule",
+            "cinn.auto_schedule.cost_model",
+            "cinn.ir",
+            "cinn.libs",
+            "cinn.version"
+]
+
+with redirect_stdout():
+    setup(
+        name='${PACKAGE_NAME}',
+        version='${CINN_VERSION}',
+        description='CINN: a Compiler Infrastructure for Neural Networks',
+        maintainer="PaddlePaddle",
+        maintainer_email="Paddle-better@baidu.com",
+        url='https://github.com/PaddlePaddle/Paddle',
+        license='Apache Software License',
+        packages=packages,
+        package_data=package_data
+    )

From 796f1d5eb20ba0d3e3ac6977e1683a3403228aa7 Mon Sep 17 00:00:00 2001
From: 6clc
Date: Thu, 15 Jun 2023 14:58:13 +0800
Subject: [PATCH 02/14] feat(cmake): add cmake of cinn python test

---
 test/cinn/CMakeLists.txt | 269 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 269 insertions(+)
 create mode 100644 test/cinn/CMakeLists.txt

diff --git a/test/cinn/CMakeLists.txt b/test/cinn/CMakeLists.txt
new file mode 100644
index 0000000000000..3497958bf3d73
--- /dev/null
+++ b/test/cinn/CMakeLists.txt
@@ -0,0 +1,269 @@
+set(CINN_PYTHON_TEST_DIR ${PROJECT_SOURCE_DIR}/test/cinn)
+set(CINN_CORE_API ${CMAKE_BINARY_DIR}/python/cinn/core_api.so)
+
+add_custom_command(
+  OUTPUT ${CINN_CORE_API}
+  COMMAND cp --remove-destination
+          ${CMAKE_BINARY_DIR}/paddle/cinn/pybind/core_api.so ${CINN_CORE_API}
+  DEPENDS core_api ${CINN_PY_FILES})
+
+
+set(BASIC_TEST_NAMES
+    test_matmul
+    test_common
+    test_packed_func
+    test_pe_elementwise
+    test_pe_reduction
+    test_pe_transform
+    test_op_broadcast
+    # test_op_transform
+)
+
+foreach(basic_test_name ${BASIC_TEST_NAMES})
+  add_test(
+    NAME ${basic_test_name}
+    COMMAND
+      ${CMAKE_COMMAND} -E env
+      PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/${basic_test_name}.py
+    WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
+endforeach()
+
+if(NOT ${WITH_GPU})
+  # ADD_TEST(NAME test_op_nn
+  #   COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH}
+  #     python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_op_nn.py WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
+  # )
+endif()
+
+if(WITH_GPU)
+  # TODO(thisjiang): revert test_cinn_frontend after fix inference mul problem
+  # ADD_TEST(NAME test_cinn_frontend
+  #   COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH}
+  #     python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_frontend.py
+  #     ${CMAKE_BINARY_DIR}/thirds/naive_mul_model
+  #     ${CMAKE_BINARY_DIR}/thirds/multi_fc_model
+  #     "${WITH_GPU}" WORKING_DIRECTORY
${CMAKE_BINARY_DIR} + # ) + add_test( + NAME test_netbuilder + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_netbuilder.py "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +endif() + +#ADD_TEST(NAME test_computation_python +# COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} +# python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_computation.py +# ${CMAKE_BINARY_DIR}/thirds/naive_mul_model +# "${WITH_GPU}" WORKING_DIRECTORY ${CMAKE_BINARY_DIR} +#) + +#ADD_TEST(NAME test_cinn_ops_check +# COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} +# python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_ops.py "${WITH_GPU}" +# WORKING_DIRECTORY ${CMAKE_BINARY_DIR} +#) + +add_test( + NAME test_cinn_op_benchmark + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_op_benchmark.py "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + +if(WITH_GPU) + add_test( + NAME test_cinn_fake_resnet + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet.py + "${CMAKE_BINARY_DIR}/thirds/resnet_model" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_resnet18 + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet18.py + "${CMAKE_BINARY_DIR}/thirds/ResNet18" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_mobilenetV2 + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv2.py + "${CMAKE_BINARY_DIR}/thirds/MobileNetV2" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_efficientnet + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_efficientnet.py + "${CMAKE_BINARY_DIR}/thirds/EfficientNet" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_mobilenetV1 + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv1.py + "${CMAKE_BINARY_DIR}/thirds/MobilenetV1" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_resnet50 + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet50.py + "${CMAKE_BINARY_DIR}/thirds/ResNet50" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_cinn_real_squeezenet + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_squeezenet.py + "${CMAKE_BINARY_DIR}/thirds/SqueezeNet" "${WITH_GPU}" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_paddle_model_convertor + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/test_paddle_model_convertor.py --path + "${CMAKE_BINARY_DIR}/thirds/resnet_model" + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +endif() + +#ADD_TEST(NAME test_cinn_real_facedet +# COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} +# python3 
${CMAKE_CURRENT_SOURCE_DIR}/test_facedet.py "${CMAKE_BINARY_DIR}/thirds/FaceDet" "${WITH_GPU}" +# WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +if(WITH_GPU) + file( + GLOB CINN_OP_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "ops/test_*.py") + set(EXCLUDE_OP test_conv2d_op) + + if(WITH_GPU) + add_test( + NAME test_conv2d_op + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/ops/test_conv2d_op.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endif() + + foreach(op_test_name ${EXCLUDE_OP}) + list(REMOVE_ITEM CINN_OP_TEST ops/${op_test_name}.py) + endforeach() + + foreach(op_test_name ${CINN_OP_TEST}) + string(REGEX REPLACE ".py" "" op_test_name ${op_test_name}) + add_test( + NAME ${op_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/${op_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + + # test op mapper + file( + GLOB CINN_OP_MAPPER_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "op_mappers/test_*.py") + set(EXCLUDE_OP_MAPPER test_mul_op test_conv2d_op) + + if(WITH_GPU) + add_test( + NAME test_mul_op_mapper + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/op_mappers/test_mul_op.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + + add_test( + NAME test_conv2d_op_mapper + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/op_mappers/test_conv2d_op.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endif() + + foreach(op_mapper_test_name ${EXCLUDE_OP_MAPPER}) + list(REMOVE_ITEM CINN_OP_MAPPER_TEST op_mappers/${op_mapper_test_name}.py) + endforeach() + + foreach(op_mapper_test_name ${CINN_OP_MAPPER_TEST}) + string(REGEX REPLACE ".py" "" op_mapper_test_name ${op_mapper_test_name}) + add_test( + NAME "${op_mapper_test_name}_mapper" + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/${op_mapper_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + + # test pass test + file( + GLOB CINN_PASS_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "passes/test_*.py") + + foreach(pass_test_name ${EXCLUDE_PASS}) + list(REMOVE_ITEM CINN_PASS_TEST passes/${pass_test_name}.py) + endforeach() + + foreach(pass_test_name ${CINN_PASS_TEST}) + string(REGEX REPLACE ".py" "" pass_test_name ${pass_test_name}) + add_test( + NAME ${pass_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/${pass_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + + file( + GLOB CINN_FUSION_TEST + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "fusion/test_*.py") + + foreach(fusion_test_name ${EXCLUDE_FUSION}) + list(REMOVE_ITEM CINN_FUSION_TEST fusion/${fusion_test_name}.py) + endforeach() + + foreach(fusion_test_name ${CINN_FUSION_TEST}) + string(REGEX REPLACE ".py" "" fusion_test_name ${fusion_test_name}) + add_test( + NAME ${fusion_test_name} + COMMAND + ${CMAKE_COMMAND} -E env + PYTHONPATH=${PROJECT_SOURCE_DIR}:$ENV{PYTHONPATH} python3 + ${CMAKE_CURRENT_SOURCE_DIR}/${fusion_test_name}.py + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) + endforeach() + +endif() From 796f1d5eb20ba0d3e3ac6977e1683a3403228aa7 Mon Sep 17 00:00:00 2001 From: 6clc Date: Thu, 15 Jun 2023 16:34:18 +0800 Subject: [PATCH 03/14] feat(cmake): 
add jit --- cmake/cinn/external/jitify.cmake | 1 - paddle/cinn/CMakeLists.txt | 21 + paddle/cinn/auto_schedule/CMakeLists.txt | 22 + .../auto_schedule/analysis/CMakeLists.txt | 5 + .../cinn/auto_schedule/analysis/analyze_ir.cc | 176 + .../cinn/auto_schedule/analysis/analyze_ir.h | 48 + .../auto_schedule/analysis/analyze_ir_test.cc | 181 + paddle/cinn/auto_schedule/auto_schedule.proto | 26 + paddle/cinn/auto_schedule/auto_tuner.cc | 163 + paddle/cinn/auto_schedule/auto_tuner.h | 79 + paddle/cinn/auto_schedule/auto_tuner_test.cc | 164 + .../auto_schedule/cost_model/CMakeLists.txt | 7 + .../cost_model/expr_cost_model.cc | 77 + .../cost_model/expr_cost_model.h | 45 + .../cinn/auto_schedule/cost_model/feature.cc | 175 + .../cinn/auto_schedule/cost_model/feature.h | 178 + .../cost_model/feature_extractor.cc | 299 ++ .../cost_model/feature_extractor.h | 60 + .../cost_model/feature_extractor_test.cc | 158 + .../auto_schedule/cost_model/feature_test.cc | 28 + .../cost_model/xgb_cost_model.cc | 135 + .../auto_schedule/cost_model/xgb_cost_model.h | 75 + .../cost_model/xgb_cost_model_test.cc | 69 + .../auto_schedule/database/CMakeLists.txt | 6 + .../cinn/auto_schedule/database/database.cc | 122 + paddle/cinn/auto_schedule/database/database.h | 102 + .../auto_schedule/database/database_test.cc | 70 + .../database/jsonfile_database.cc | 99 + .../database/jsonfile_database.h | 52 + .../database/jsonfile_database_test.cc | 214 ++ .../cinn/auto_schedule/measure/CMakeLists.txt | 6 + paddle/cinn/auto_schedule/measure/measure.h | 79 + .../auto_schedule/measure/measurer_test.cc | 127 + .../measure/schedule_measurer.cc | 77 + .../auto_schedule/measure/schedule_measurer.h | 44 + .../auto_schedule/measure/simple_builder.cc | 41 + .../auto_schedule/measure/simple_builder.h | 37 + .../auto_schedule/measure/simple_runner.cc | 227 ++ .../auto_schedule/measure/simple_runner.h | 43 + .../measure/simple_runner_test.cc | 139 + .../post_schedule_rule/CMakeLists.txt | 9 + .../post_schedule_rule/cooperative_process.cc | 70 + .../post_schedule_rule/cooperative_process.h | 34 + .../cooperative_process_test.cc | 199 ++ .../post_schedule_rule/post_schedule_rule.h | 38 + .../auto_schedule/search_space/CMakeLists.txt | 15 + .../search_space/auto_gen_rule/CMakeLists.txt | 24 + .../search_space/auto_gen_rule/auto_bind.cc | 163 + .../search_space/auto_gen_rule/auto_bind.h | 48 + .../auto_gen_rule/auto_bind_test.cc | 118 + .../auto_gen_rule/auto_gen_rule.cc | 41 + .../auto_gen_rule/auto_gen_rule.h | 84 + .../search_space/auto_gen_rule/auto_inline.cc | 214 ++ .../search_space/auto_gen_rule/auto_inline.h | 71 + .../auto_gen_rule/auto_inline_test.cc | 493 +++ .../search_space/auto_gen_rule/auto_unroll.cc | 120 + .../search_space/auto_gen_rule/auto_unroll.h | 54 + .../auto_gen_rule/auto_unroll_test.cc | 107 + .../auto_gen_rule/mix_rules_test.cc | 66 + .../auto_gen_rule/multi_level_tiling.cc | 401 +++ .../auto_gen_rule/multi_level_tiling.h | 138 + .../auto_gen_rule/multi_level_tiling_test.cc | 548 +++ .../search_space/auto_gen_rule/skip_rule.cc | 38 + .../search_space/auto_gen_rule/skip_rule.h | 45 + .../auto_gen_rule/skip_rule_test.cc | 122 + .../search_space/auto_gen_rule/test_helper.cc | 240 ++ .../search_space/auto_gen_rule/test_helper.h | 92 + .../search_space/block_sampler.cc | 92 + .../search_space/block_sampler.h | 115 + .../search_space/block_sampler_test.cc | 73 + .../search_space/rule_sampler.cc | 80 + .../auto_schedule/search_space/rule_sampler.h | 114 + .../search_space/rule_sampler_test.cc | 75 + 
.../search_space/search_space.cc | 301 ++ .../auto_schedule/search_space/search_space.h | 104 + .../search_space/search_space_test.cc | 21 + .../search_space/search_state.cc | 152 + .../auto_schedule/search_space/search_state.h | 87 + .../search_space/search_state_test.cc | 136 + .../search_strategy/CMakeLists.txt | 7 + .../search_strategy/evolutionary_search.cc | 302 ++ .../search_strategy/evolutionary_search.h | 146 + .../evolutionary_search_test.cc | 196 ++ .../mutate_rule/CMakeLists.txt | 8 + .../mutate_rule/mutate_rule.cc | 32 + .../search_strategy/mutate_rule/mutate_rule.h | 48 + .../mutate_rule/mutate_tile_size.cc | 142 + .../mutate_rule/mutate_tile_size.h | 33 + .../mutate_rule/mutate_tile_size_test.cc | 126 + paddle/cinn/auto_schedule/task/CMakeLists.txt | 12 + .../cinn/auto_schedule/task/task_creator.cc | 57 + paddle/cinn/auto_schedule/task/task_creator.h | 36 + .../auto_schedule/task/task_creator_test.cc | 72 + .../cinn/auto_schedule/task/task_optimizer.cc | 407 +++ .../cinn/auto_schedule/task/task_optimizer.h | 70 + .../cinn/auto_schedule/task/task_registry.h | 79 + .../auto_schedule/task/task_registry_test.cc | 105 + paddle/cinn/auto_schedule/task/tune_task.cc | 97 + paddle/cinn/auto_schedule/task/tune_task.h | 69 + .../cinn/auto_schedule/task/tune_task_test.cc | 339 ++ .../task_scheduler/CMakeLists.txt | 5 + .../task_scheduler/efficiency_priority.cc | 33 + .../task_scheduler/efficiency_priority.h | 39 + .../task_scheduler/round_robin.cc | 28 + .../task_scheduler/round_robin.h | 36 + .../task_scheduler/task_scheduler.cc | 46 + .../task_scheduler/task_scheduler.h | 67 + .../task_scheduler/task_scheduler_test.cc | 56 + .../cinn/auto_schedule/tests/CMakeLists.txt | 5 + .../tests/performance_comparison_test.cc | 310 ++ paddle/cinn/auto_schedule/tuning.h | 91 + paddle/cinn/backends/CMakeLists.txt | 67 + paddle/cinn/backends/_x86_builtin_source.cc | 378 +++ paddle/cinn/backends/codegen_c.cc | 868 +++++ paddle/cinn/backends/codegen_c.h | 127 + paddle/cinn/backends/codegen_c_test.cc | 436 +++ paddle/cinn/backends/codegen_c_x86.cc | 153 + paddle/cinn/backends/codegen_c_x86.h | 131 + paddle/cinn/backends/codegen_c_x86_test.cc | 77 + paddle/cinn/backends/codegen_cuda_dev.cc | 391 +++ paddle/cinn/backends/codegen_cuda_dev.h | 110 + .../backends/codegen_cuda_generate_test.cc | 68 + paddle/cinn/backends/codegen_cuda_host.cc | 173 + paddle/cinn/backends/codegen_cuda_host.h | 56 + paddle/cinn/backends/codegen_cuda_util.cc | 30 + paddle/cinn/backends/codegen_cuda_util.h | 140 + paddle/cinn/backends/codegen_debug_test.cc | 121 + paddle/cinn/backends/compiler.cc | 163 + paddle/cinn/backends/compiler.h | 94 + paddle/cinn/backends/compiler_test.cc | 196 ++ paddle/cinn/backends/cuda_util.cc | 56 + paddle/cinn/backends/cuda_util.h | 100 + paddle/cinn/backends/extern_func_emitter.cc | 81 + paddle/cinn/backends/extern_func_emitter.h | 134 + .../backends/extern_func_emitter_builtin.cc | 87 + .../backends/extern_func_emitter_builtin.h | 61 + .../cinn/backends/extern_func_jit_register.cc | 40 + .../cinn/backends/extern_func_jit_register.h | 161 + paddle/cinn/backends/extern_func_protos.cc | 66 + paddle/cinn/backends/extern_func_protos.h | 43 + paddle/cinn/backends/function_prototype.cc | 130 + paddle/cinn/backends/function_prototype.h | 130 + paddle/cinn/backends/generated1.cu | 15 + paddle/cinn/backends/generated_module1.cc | 15 + paddle/cinn/backends/ir_schedule_test.cc | 3019 +++++++++++++++++ paddle/cinn/backends/llvm/CMakeLists.txt | 41 + paddle/cinn/backends/llvm/codegen_llvm.cc | 1527 +++++++++ 
paddle/cinn/backends/llvm/codegen_llvm.h | 248 ++ .../cinn/backends/llvm/codegen_llvm_test.cc | 623 ++++ paddle/cinn/backends/llvm/codegen_x86.cc | 163 + paddle/cinn/backends/llvm/codegen_x86.h | 59 + paddle/cinn/backends/llvm/codegen_x86_test.cc | 73 + paddle/cinn/backends/llvm/execution_engine.cc | 250 ++ paddle/cinn/backends/llvm/execution_engine.h | 104 + .../backends/llvm/execution_engine_test.cc | 329 ++ .../backends/llvm/generate_runtime_llvm_ir.py | 57 + paddle/cinn/backends/llvm/ir_builder_mixin.h | 306 ++ paddle/cinn/backends/llvm/llvm_intrin_rule.h | 177 + paddle/cinn/backends/llvm/llvm_optimizer.cc | 166 + paddle/cinn/backends/llvm/llvm_optimizer.h | 43 + paddle/cinn/backends/llvm/llvm_util.cc | 146 + paddle/cinn/backends/llvm/llvm_util.h | 55 + .../backends/llvm/runtime_symbol_registry.cc | 68 + .../backends/llvm/runtime_symbol_registry.h | 113 + paddle/cinn/backends/llvm/simple_jit.cc | 133 + paddle/cinn/backends/llvm/simple_jit.h | 82 + paddle/cinn/backends/modular.cc | 128 + paddle/cinn/backends/modular.h | 40 + paddle/cinn/backends/nvrtc/CMakeLists.txt | 8 + .../cinn/backends/nvrtc/header_generator.cc | 44 + paddle/cinn/backends/nvrtc/header_generator.h | 47 + paddle/cinn/backends/nvrtc/nvrtc_util.cc | 239 ++ paddle/cinn/backends/nvrtc/nvrtc_util.h | 92 + paddle/cinn/backends/nvrtc/nvrtc_util_test.cc | 90 + paddle/cinn/backends/outputs.cc | 50 + paddle/cinn/backends/outputs.h | 52 + paddle/cinn/backends/raw_cuda_code_test.cu | 54 + paddle/cinn/cinn.h | 56 + paddle/cinn/common/CMakeLists.txt | 36 + paddle/cinn/common/arithmatic.cc | 310 ++ paddle/cinn/common/arithmatic.h | 85 + paddle/cinn/common/arithmatic_test.cc | 92 + paddle/cinn/common/axis.cc | 115 + paddle/cinn/common/axis.h | 45 + paddle/cinn/common/axis_test.cc | 45 + paddle/cinn/common/bfloat16.h | 402 +++ paddle/cinn/common/cas.cc | 2200 ++++++++++++ paddle/cinn/common/cas.h | 166 + paddle/cinn/common/cas_test.cc | 432 +++ paddle/cinn/common/cinn_value.cc | 251 ++ paddle/cinn/common/cinn_value.h | 222 ++ paddle/cinn/common/cinn_value_test.cc | 59 + paddle/cinn/common/common.h | 69 + paddle/cinn/common/context.cc | 80 + paddle/cinn/common/context.h | 95 + paddle/cinn/common/cost_model.h | 40 + paddle/cinn/common/cuda_test_helper.cc | 96 + paddle/cinn/common/cuda_test_helper.h | 56 + paddle/cinn/common/debug_manager.cc | 68 + paddle/cinn/common/debug_manager.h | 50 + paddle/cinn/common/float16.h | 629 ++++ .../cinn/common/float16_bfloat16_cuda_test.cu | 236 ++ .../cinn/common/float16_bfloat16_host_test.cc | 98 + paddle/cinn/common/float16_bfloat16_utils.h | 183 + paddle/cinn/common/graph_utils.cc | 212 ++ paddle/cinn/common/graph_utils.h | 289 ++ paddle/cinn/common/graph_utils_test.cc | 92 + paddle/cinn/common/info_registry.cc | 19 + paddle/cinn/common/info_registry.h | 50 + paddle/cinn/common/ir_util.cc | 417 +++ paddle/cinn/common/ir_util.h | 146 + paddle/cinn/common/macros.h | 51 + paddle/cinn/common/object.cc | 19 + paddle/cinn/common/object.h | 78 + .../cinn/common/python_interpreter_guard.cc | 32 + paddle/cinn/common/python_interpreter_guard.h | 43 + paddle/cinn/common/shared.cc | 15 + paddle/cinn/common/shared.h | 151 + paddle/cinn/common/shared_test.cc | 55 + paddle/cinn/common/target.cc | 225 ++ paddle/cinn/common/target.h | 115 + paddle/cinn/common/test_helper.cc | 80 + paddle/cinn/common/test_helper.h | 115 + paddle/cinn/common/type.cc | 570 ++++ paddle/cinn/common/type.h | 316 ++ paddle/cinn/common/type_test.cc | 31 + paddle/cinn/common/union_find.cc | 24 + paddle/cinn/common/union_find.h | 100 + 
paddle/cinn/frontend/CMakeLists.txt | 49 + paddle/cinn/frontend/computation.cc | 243 ++ paddle/cinn/frontend/computation.h | 161 + paddle/cinn/frontend/computation_test.cc | 300 ++ .../cinn/frontend/decomposer/CMakeLists.txt | 19 + paddle/cinn/frontend/decomposer/activation.cc | 146 + .../frontend/decomposer/activation_test.cc | 99 + paddle/cinn/frontend/decomposer/batch_norm.cc | 302 ++ .../frontend/decomposer/batch_norm_test.cc | 420 +++ paddle/cinn/frontend/decomposer/broadcast.cc | 175 + .../frontend/decomposer/broadcast_test.cc | 281 ++ .../cinn/frontend/decomposer/elementwise.cc | 46 + .../frontend/decomposer/elementwise_test.cc | 45 + .../cinn/frontend/decomposer/test_helper.cc | 88 + paddle/cinn/frontend/decomposer/test_helper.h | 242 ++ paddle/cinn/frontend/decomposer/top_k.cc | 54 + paddle/cinn/frontend/decomposer/top_k_test.cc | 55 + .../cinn/frontend/decomposer/use_decomposer.h | 29 + paddle/cinn/frontend/decomposer_registry.h | 128 + .../cinn/frontend/decomposer_registry_test.cc | 29 + paddle/cinn/frontend/interpreter.cc | 142 + paddle/cinn/frontend/interpreter.h | 66 + paddle/cinn/frontend/interpreter_test.cc | 34 + paddle/cinn/frontend/net_builder.cc | 939 +++++ paddle/cinn/frontend/net_builder.h | 1146 +++++++ paddle/cinn/frontend/net_builder_test.cc | 1501 ++++++++ paddle/cinn/frontend/op_mapper_registry.cc | 89 + paddle/cinn/frontend/op_mapper_registry.h | 151 + .../cinn/frontend/op_mapper_registry_test.cc | 42 + .../cinn/frontend/op_mappers/CMakeLists.txt | 4 + .../cinn/frontend/op_mappers/common_utils.h | 171 + .../frontend/op_mappers/paddle/CMakeLists.txt | 3 + .../frontend/op_mappers/paddle/arg_min_max.cc | 92 + .../frontend/op_mappers/paddle/argsort.cc | 59 + .../cinn/frontend/op_mappers/paddle/atan.cc | 54 + .../frontend/op_mappers/paddle/batchnorm.cc | 168 + .../cinn/frontend/op_mappers/paddle/binary.cc | 64 + .../frontend/op_mappers/paddle/cholesky.cc | 45 + .../cinn/frontend/op_mappers/paddle/clip.cc | 81 + .../frontend/op_mappers/paddle/compare.cc | 79 + .../cinn/frontend/op_mappers/paddle/concat.cc | 173 + .../frontend/op_mappers/paddle/constant.cc | 238 ++ .../cinn/frontend/op_mappers/paddle/conv2d.cc | 178 + .../cinn/frontend/op_mappers/paddle/cumsum.cc | 96 + .../frontend/op_mappers/paddle/dropout.cc | 45 + .../frontend/op_mappers/paddle/elementwise.cc | 257 ++ .../cinn/frontend/op_mappers/paddle/expand.cc | 120 + .../frontend/op_mappers/paddle/fetch_feed.cc | 61 + .../cinn/frontend/op_mappers/paddle/flip.cc | 51 + .../cinn/frontend/op_mappers/paddle/gather.cc | 68 + .../frontend/op_mappers/paddle/gather_nd.cc | 49 + .../op_mappers/paddle/gaussian_random.cc | 52 + .../frontend/op_mappers/paddle/layer_norm.cc | 160 + paddle/cinn/frontend/op_mappers/paddle/log.cc | 88 + .../op_mappers/paddle/lookup_table.cc | 65 + .../cinn/frontend/op_mappers/paddle/matmul.cc | 57 + paddle/cinn/frontend/op_mappers/paddle/mul.cc | 51 + .../cinn/frontend/op_mappers/paddle/norm.cc | 110 + .../frontend/op_mappers/paddle/one_hot.cc | 75 + .../cinn/frontend/op_mappers/paddle/pool2d.cc | 115 + .../frontend/op_mappers/paddle/randint.cc | 56 + .../cinn/frontend/op_mappers/paddle/reduce.cc | 119 + .../cinn/frontend/op_mappers/paddle/relu.cc | 73 + .../frontend/op_mappers/paddle/reshape.cc | 135 + .../frontend/op_mappers/paddle/reverse.cc | 45 + .../cinn/frontend/op_mappers/paddle/roll.cc | 106 + .../cinn/frontend/op_mappers/paddle/scale.cc | 80 + .../frontend/op_mappers/paddle/scatter.cc | 75 + .../cinn/frontend/op_mappers/paddle/slice.cc | 52 + .../frontend/op_mappers/paddle/softmax.cc 
| 44 + .../frontend/op_mappers/paddle/squeeze.cc | 63 + .../op_mappers/paddle/strided_slice.cc | 55 + .../op_mappers/paddle/take_along_axis.cc | 52 + .../cinn/frontend/op_mappers/paddle/tile.cc | 98 + .../cinn/frontend/op_mappers/paddle/top_k.cc | 48 + .../frontend/op_mappers/paddle/transpose.cc | 82 + .../op_mappers/paddle/triangular_solve.cc | 52 + .../cinn/frontend/op_mappers/paddle/unary.cc | 109 + .../op_mappers/paddle/uniform_random.cc | 56 + .../frontend/op_mappers/paddle/unsqueeze.cc | 63 + .../cinn/frontend/op_mappers/paddle/where.cc | 50 + .../op_mappers/science/CMakeLists.txt | 3 + .../frontend/op_mappers/science/broadcast.cc | 77 + .../frontend/op_mappers/science/compare.cc | 64 + .../cinn/frontend/op_mappers/science/math.cc | 107 + .../frontend/op_mappers/science/transform.cc | 404 +++ .../cinn/frontend/op_mappers/use_op_mappers.h | 72 + paddle/cinn/frontend/optimize.cc | 177 + paddle/cinn/frontend/optimize.h | 49 + paddle/cinn/frontend/paddle/CMakeLists.txt | 26 + paddle/cinn/frontend/paddle/README.md | 1 + paddle/cinn/frontend/paddle/compatible_pb.cc | 266 ++ paddle/cinn/frontend/paddle/compatible_pb.h | 56 + .../cinn/frontend/paddle/cpp/CMakeLists.txt | 18 + paddle/cinn/frontend/paddle/cpp/block_desc.cc | 55 + paddle/cinn/frontend/paddle/cpp/block_desc.h | 79 + paddle/cinn/frontend/paddle/cpp/desc_api.h | 250 ++ paddle/cinn/frontend/paddle/cpp/op_desc.cc | 152 + paddle/cinn/frontend/paddle/cpp/op_desc.h | 110 + .../cinn/frontend/paddle/cpp/program_desc.cc | 37 + .../cinn/frontend/paddle/cpp/program_desc.h | 59 + paddle/cinn/frontend/paddle/cpp/var_desc.cc | 17 + paddle/cinn/frontend/paddle/cpp/var_desc.h | 58 + paddle/cinn/frontend/paddle/framework.proto | 214 ++ paddle/cinn/frontend/paddle/model_parser.cc | 272 ++ paddle/cinn/frontend/paddle/model_parser.h | 66 + .../cinn/frontend/paddle/model_parser_test.cc | 45 + paddle/cinn/frontend/paddle/pb/CMakeLists.txt | 16 + paddle/cinn/frontend/paddle/pb/block_desc.cc | 41 + paddle/cinn/frontend/paddle/pb/block_desc.h | 71 + paddle/cinn/frontend/paddle/pb/op_desc.cc | 124 + paddle/cinn/frontend/paddle/pb/op_desc.h | 169 + .../cinn/frontend/paddle/pb/program_desc.cc | 33 + paddle/cinn/frontend/paddle/pb/program_desc.h | 57 + paddle/cinn/frontend/paddle/pb/var_desc.cc | 341 ++ paddle/cinn/frontend/paddle/pb/var_desc.h | 115 + .../cinn/frontend/paddle_model_convertor.cc | 204 ++ paddle/cinn/frontend/paddle_model_convertor.h | 99 + .../frontend/paddle_model_convertor_test.cc | 108 + .../cinn/frontend/paddle_model_to_program.cc | 736 ++++ .../cinn/frontend/paddle_model_to_program.h | 141 + paddle/cinn/frontend/pass/CMakeLists.txt | 36 + paddle/cinn/frontend/pass/auto_broadcast.cc | 139 + paddle/cinn/frontend/pass/auto_cast.cc | 250 ++ paddle/cinn/frontend/pass/auto_cast_test.cc | 86 + paddle/cinn/frontend/pass/cast_collapsing.cc | 347 ++ .../frontend/pass/cast_collapsing_test.cc | 200 ++ .../cinn/frontend/pass/dead_code_eliminate.cc | 116 + .../frontend/pass/dead_code_eliminate_test.cc | 81 + paddle/cinn/frontend/pass/decomposer.cc | 85 + paddle/cinn/frontend/pass/decomposer_test.cc | 88 + .../frontend/pass/expand_zero_dim_pass.cc | 73 + .../pass/expand_zero_dim_pass_test.cc | 157 + .../frontend/pass/fill_constant_folding.cc | 191 ++ .../pass/fill_constant_folding_test.cc | 210 ++ .../frontend/pass/fill_constant_rewriter.cc | 226 ++ .../pass/fill_constant_rewriter_test.cc | 160 + paddle/cinn/frontend/pass/gemm_rewriter.cc | 216 ++ .../cinn/frontend/pass/gemm_rewriter_test.cc | 280 ++ paddle/cinn/frontend/pass/pass_test_helper.h | 212 
++ .../frontend/pass/program_topoerror_test.cc | 71 + paddle/cinn/frontend/pass/remove_identity.cc | 276 ++ .../frontend/pass/remove_identity_test.cc | 123 + paddle/cinn/frontend/pass/test_helper.h | 185 + .../frontend/pass/transpose_collapsing.cc | 393 +++ .../pass/transpose_collapsing_test.cc | 455 +++ .../frontend/pass/transpose_folding_base.h | 211 ++ .../frontend/pass/transpose_folding_input.cc | 161 + .../pass/transpose_folding_input_test.cc | 257 ++ .../frontend/pass/transpose_folding_output.cc | 111 + .../pass/transpose_folding_output_test.cc | 568 ++++ .../pass/transpose_scale_folding_test.cc | 370 ++ paddle/cinn/frontend/pass/use_program_pass.h | 31 + paddle/cinn/frontend/program_pass.cc | 46 + paddle/cinn/frontend/program_pass.h | 116 + paddle/cinn/frontend/syntax.cc | 568 ++++ paddle/cinn/frontend/syntax.h | 507 +++ paddle/cinn/frontend/syntax_test.cc | 143 + paddle/cinn/frontend/var_type_utils.h | 102 + paddle/cinn/gtest_main.cc | 23 + paddle/cinn/hlir/CMakeLists.txt | 5 + paddle/cinn/hlir/framework/CMakeLists.txt | 41 + .../cinn/hlir/framework/accuracy_checker.cc | 312 ++ paddle/cinn/hlir/framework/accuracy_checker.h | 52 + .../hlir/framework/accuracy_checker_test.cc | 162 + paddle/cinn/hlir/framework/buffer.cc | 96 + paddle/cinn/hlir/framework/buffer.h | 91 + paddle/cinn/hlir/framework/buffer_test.cc | 61 + paddle/cinn/hlir/framework/graph.cc | 514 +++ paddle/cinn/hlir/framework/graph.h | 232 ++ paddle/cinn/hlir/framework/graph_compiler.cc | 1532 +++++++++ paddle/cinn/hlir/framework/graph_compiler.h | 210 ++ .../hlir/framework/graph_compiler_test.cc | 216 ++ paddle/cinn/hlir/framework/graph_test.cc | 70 + paddle/cinn/hlir/framework/instruction.cc | 333 ++ paddle/cinn/hlir/framework/instruction.h | 150 + .../cinn/hlir/framework/instruction_test.cc | 482 +++ paddle/cinn/hlir/framework/memory.cc | 68 + paddle/cinn/hlir/framework/memory.h | 77 + paddle/cinn/hlir/framework/node.cc | 177 + paddle/cinn/hlir/framework/node.h | 212 ++ paddle/cinn/hlir/framework/op.h | 248 ++ paddle/cinn/hlir/framework/op_lowering.cc | 1351 ++++++++ paddle/cinn/hlir/framework/op_lowering.h | 100 + .../cinn/hlir/framework/op_lowering_test.cc | 1268 +++++++ .../cinn/hlir/framework/op_lowering_util.cc | 1661 +++++++++ paddle/cinn/hlir/framework/op_lowering_util.h | 102 + paddle/cinn/hlir/framework/op_strategy.cc | 55 + paddle/cinn/hlir/framework/op_strategy.h | 138 + paddle/cinn/hlir/framework/op_test.cc | 87 + .../cinn/hlir/framework/parallel_compiler.cc | 230 ++ .../cinn/hlir/framework/parallel_compiler.h | 97 + .../hlir/framework/parallel_compiler_test.cc | 83 + paddle/cinn/hlir/framework/pass.cc | 58 + paddle/cinn/hlir/framework/pass.h | 109 + .../hlir/framework/print_graph_pass_test.cc | 78 + paddle/cinn/hlir/framework/schedule.h | 64 + paddle/cinn/hlir/framework/scope.cc | 51 + paddle/cinn/hlir/framework/scope.h | 75 + paddle/cinn/hlir/framework/scope_test.cc | 45 + paddle/cinn/hlir/framework/tensor.cc | 58 + paddle/cinn/hlir/framework/tensor.h | 118 + paddle/cinn/hlir/framework/tensor_test.cc | 36 + paddle/cinn/hlir/framework/variable.cc | 21 + paddle/cinn/hlir/framework/variable.h | 21 + .../cinn/hlir/framework/visualize_helper.cc | 440 +++ paddle/cinn/hlir/framework/visualize_helper.h | 161 + paddle/cinn/hlir/kernels/CMakeLists.txt | 0 paddle/cinn/hlir/op/CMakeLists.txt | 23 + paddle/cinn/hlir/op/broadcast.cc | 455 +++ paddle/cinn/hlir/op/contrib/CMakeLists.txt | 29 + paddle/cinn/hlir/op/contrib/argmax.cc | 249 ++ paddle/cinn/hlir/op/contrib/argmax.h | 32 + 
paddle/cinn/hlir/op/contrib/argmax_test.cc | 119 + paddle/cinn/hlir/op/contrib/argmin.cc | 248 ++ paddle/cinn/hlir/op/contrib/argmin.h | 32 + paddle/cinn/hlir/op/contrib/argmin_test.cc | 118 + paddle/cinn/hlir/op/contrib/assert_true.cc | 89 + .../cinn/hlir/op/contrib/bitcast_convert.cc | 133 + paddle/cinn/hlir/op/contrib/cholesky.cc | 110 + paddle/cinn/hlir/op/contrib/gather_nd.cc | 193 ++ paddle/cinn/hlir/op/contrib/gather_nd.h | 32 + paddle/cinn/hlir/op/contrib/gather_nd_test.cc | 95 + .../cinn/hlir/op/contrib/gaussian_random.cc | 111 + .../hlir/op/contrib/logical_right_shift.cc | 157 + .../hlir/op/contrib/logical_right_shift.h | 35 + .../op/contrib/logical_right_shift_test.cc | 64 + paddle/cinn/hlir/op/contrib/lookup_table.cc | 148 + paddle/cinn/hlir/op/contrib/lookup_table.h | 37 + .../cinn/hlir/op/contrib/lookup_table_test.cc | 94 + paddle/cinn/hlir/op/contrib/one_hot.cc | 225 ++ paddle/cinn/hlir/op/contrib/one_hot.h | 38 + paddle/cinn/hlir/op/contrib/one_hot_test.cc | 107 + paddle/cinn/hlir/op/contrib/randint.cc | 104 + paddle/cinn/hlir/op/contrib/reciprocal.cc | 156 + paddle/cinn/hlir/op/contrib/reciprocal.h | 29 + .../cinn/hlir/op/contrib/reciprocal_test.cc | 68 + paddle/cinn/hlir/op/contrib/repeat.cc | 230 ++ paddle/cinn/hlir/op/contrib/repeat.h | 32 + paddle/cinn/hlir/op/contrib/repeat_test.cc | 116 + paddle/cinn/hlir/op/contrib/resize.cc | 241 ++ paddle/cinn/hlir/op/contrib/resize.h | 36 + paddle/cinn/hlir/op/contrib/sort.cc | 412 +++ paddle/cinn/hlir/op/contrib/sort.h | 44 + paddle/cinn/hlir/op/contrib/sort_test.cc | 133 + .../cinn/hlir/op/contrib/triangular_solve.cc | 121 + paddle/cinn/hlir/op/contrib/uniform_random.cc | 111 + paddle/cinn/hlir/op/custom_call.cc | 853 +++++ paddle/cinn/hlir/op/elementwise.cc | 1056 ++++++ paddle/cinn/hlir/op/external_api_registry.cc | 86 + paddle/cinn/hlir/op/external_api_registry.h | 78 + .../hlir/op/external_api_registry_test.cc | 50 + paddle/cinn/hlir/op/nn.cc | 2460 ++++++++++++++ paddle/cinn/hlir/op/op_broadcast_test.cc | 318 ++ paddle/cinn/hlir/op/op_nn_test.cc | 513 +++ paddle/cinn/hlir/op/op_util.cc | 169 + paddle/cinn/hlir/op/op_util.h | 140 + paddle/cinn/hlir/op/reduction.cc | 505 +++ paddle/cinn/hlir/op/reduction_test.cc | 561 +++ paddle/cinn/hlir/op/transform.cc | 1797 ++++++++++ paddle/cinn/hlir/op/transform_test.cc | 121 + paddle/cinn/hlir/op/use_ops.h | 43 + paddle/cinn/hlir/pass/CMakeLists.txt | 42 + paddle/cinn/hlir/pass/alterlayout.cc | 649 ++++ paddle/cinn/hlir/pass/alterlayout_test.cc | 458 +++ .../hlir/pass/check_fusion_accuracy_pass.cc | 546 +++ .../pass/check_fusion_accuracy_pass_test.cc | 589 ++++ .../pass/common_subexpression_elimination.cc | 307 ++ .../common_subexpression_elimination_test.cc | 198 ++ paddle/cinn/hlir/pass/const_propagate.cc | 74 + paddle/cinn/hlir/pass/const_propagate_test.cc | 131 + .../cinn/hlir/pass/constant_folding_pass.cc | 122 + .../hlir/pass/constant_folding_pass_test.cc | 333 ++ .../hlir/pass/constant_folding_pass_util.cc | 237 ++ .../hlir/pass/constant_folding_pass_util.h | 39 + paddle/cinn/hlir/pass/custom_call_pass.cc | 100 + paddle/cinn/hlir/pass/dce_pass.cc | 135 + paddle/cinn/hlir/pass/dce_pass_test.cc | 64 + paddle/cinn/hlir/pass/dense_merge_pass.cc | 187 + .../cinn/hlir/pass/dense_merge_pass_test.cc | 168 + paddle/cinn/hlir/pass/dot_merger.cc | 437 +++ paddle/cinn/hlir/pass/dot_merger_test.cc | 112 + paddle/cinn/hlir/pass/fusion_helper_base.h | 208 ++ paddle/cinn/hlir/pass/fusion_merge_pass.cc | 1032 ++++++ .../cinn/hlir/pass/fusion_merge_pass_test.cc | 487 +++ 
.../cinn/hlir/pass/fusion_merge_pass_util.h | 561 +++ paddle/cinn/hlir/pass/infershape.cc | 128 + paddle/cinn/hlir/pass/infershape.h | 29 + paddle/cinn/hlir/pass/op_fusion_pass.cc | 384 +++ paddle/cinn/hlir/pass/op_fusion_pass_test.cc | 276 ++ paddle/cinn/hlir/pass/op_fusion_pass_util.h | 337 ++ paddle/cinn/hlir/pass/opfusion.cc | 536 +++ paddle/cinn/hlir/pass/opfusion_test.cc | 540 +++ paddle/cinn/hlir/pass/reduce_split_pass.cc | 230 ++ .../cinn/hlir/pass/reduce_split_pass_test.cc | 95 + .../hlir/pass/single_group_optimize_pass.cc | 201 ++ paddle/cinn/hlir/pass/test_dot_merger.cc | 100 + paddle/cinn/hlir/pass/test_primitive_ops.cc | 153 + paddle/cinn/hlir/pass/use_pass.h | 35 + paddle/cinn/hlir/pe/CMakeLists.txt | 25 + paddle/cinn/hlir/pe/broadcast.cc | 372 ++ paddle/cinn/hlir/pe/broadcast.h | 126 + paddle/cinn/hlir/pe/elementwise.cc | 233 ++ paddle/cinn/hlir/pe/elementwise.h | 129 + paddle/cinn/hlir/pe/ir_schedule_pe.cc | 1223 +++++++ paddle/cinn/hlir/pe/ir_schedule_pe.h | 102 + paddle/cinn/hlir/pe/load_params_test.cc | 62 + paddle/cinn/hlir/pe/load_x86_params.cc | 1308 +++++++ paddle/cinn/hlir/pe/load_x86_params.h | 50 + paddle/cinn/hlir/pe/nn.cc | 1290 +++++++ paddle/cinn/hlir/pe/nn.h | 426 +++ paddle/cinn/hlir/pe/nn_util.cc | 428 +++ paddle/cinn/hlir/pe/nn_util.h | 46 + paddle/cinn/hlir/pe/pe_broadcast_test.cc | 220 ++ paddle/cinn/hlir/pe/pe_elementwise_test.cc | 160 + paddle/cinn/hlir/pe/pe_transform_test.cc | 229 ++ paddle/cinn/hlir/pe/reduction.cc | 884 +++++ paddle/cinn/hlir/pe/reduction.h | 419 +++ paddle/cinn/hlir/pe/schedule.cc | 2270 +++++++++++++ paddle/cinn/hlir/pe/schedule.h | 248 ++ paddle/cinn/hlir/pe/schedule_param.proto | 29 + paddle/cinn/hlir/pe/transform.cc | 1182 +++++++ paddle/cinn/hlir/pe/transform.h | 232 ++ paddle/cinn/hlir/pe/vision.cc | 21 + paddle/cinn/hlir/pe/vision.h | 21 + paddle/cinn/ir/CMakeLists.txt | 44 + paddle/cinn/ir/buffer.cc | 169 + paddle/cinn/ir/buffer.h | 192 ++ paddle/cinn/ir/buffer_test.cc | 86 + paddle/cinn/ir/collect_ir_nodes.cc | 186 + paddle/cinn/ir/collect_ir_nodes.h | 56 + paddle/cinn/ir/collect_ir_nodes_test.cc | 58 + paddle/cinn/ir/function_base.cc | 19 + paddle/cinn/ir/function_base.h | 34 + paddle/cinn/ir/function_definition.cc | 19 + paddle/cinn/ir/function_definition.h | 43 + paddle/cinn/ir/intrinsic_ops.cc | 127 + paddle/cinn/ir/intrinsic_ops.h | 200 ++ paddle/cinn/ir/intrinsic_ops_test.cc | 31 + paddle/cinn/ir/ir.cc | 819 +++++ paddle/cinn/ir/ir.h | 999 ++++++ paddle/cinn/ir/ir_base.cc | 231 ++ paddle/cinn/ir/ir_base.h | 500 +++ paddle/cinn/ir/ir_compare.cc | 319 ++ paddle/cinn/ir/ir_compare.h | 46 + paddle/cinn/ir/ir_compare_test.cc | 124 + paddle/cinn/ir/ir_mutator.cc | 22 + paddle/cinn/ir/ir_mutator.h | 334 ++ paddle/cinn/ir/ir_operators.cc | 153 + paddle/cinn/ir/ir_operators.h | 133 + paddle/cinn/ir/ir_operators_test.cc | 28 + paddle/cinn/ir/ir_printer.cc | 645 ++++ paddle/cinn/ir/ir_printer.h | 80 + paddle/cinn/ir/ir_printer_test.cc | 23 + paddle/cinn/ir/ir_schedule.cc | 2310 +++++++++++++ paddle/cinn/ir/ir_schedule.h | 614 ++++ paddle/cinn/ir/ir_schedule_util.cc | 1038 ++++++ paddle/cinn/ir/ir_schedule_util.h | 448 +++ paddle/cinn/ir/ir_test.cc | 31 + paddle/cinn/ir/ir_verify.cc | 39 + paddle/cinn/ir/ir_verify.h | 22 + paddle/cinn/ir/ir_verify_test.cc | 29 + paddle/cinn/ir/ir_visitor.cc | 35 + paddle/cinn/ir/ir_visitor.h | 82 + paddle/cinn/ir/layout.cc | 67 + paddle/cinn/ir/layout.h | 48 + paddle/cinn/ir/lowered_func.cc | 472 +++ paddle/cinn/ir/lowered_func.h | 198 ++ paddle/cinn/ir/module.cc | 97 + paddle/cinn/ir/module.h 
| 89 + paddle/cinn/ir/operation.cc | 113 + paddle/cinn/ir/operation.h | 130 + paddle/cinn/ir/registry.cc | 93 + paddle/cinn/ir/registry.h | 46 + paddle/cinn/ir/schedule_desc.cc | 680 ++++ paddle/cinn/ir/schedule_desc.h | 106 + paddle/cinn/ir/schedule_desc.proto | 67 + paddle/cinn/ir/schedule_desc_test.cc | 809 +++++ paddle/cinn/ir/tensor.cc | 590 ++++ paddle/cinn/ir/tensor.h | 342 ++ paddle/cinn/ir/tensor_test.cc | 211 ++ paddle/cinn/lang/CMakeLists.txt | 17 + paddle/cinn/lang/README.md | 93 + paddle/cinn/lang/buffer.cc | 36 + paddle/cinn/lang/buffer.h | 44 + paddle/cinn/lang/builtin.cc | 262 ++ paddle/cinn/lang/builtin.h | 173 + paddle/cinn/lang/compute.cc | 229 ++ paddle/cinn/lang/compute.h | 132 + paddle/cinn/lang/compute_test.cc | 39 + paddle/cinn/lang/lower.cc | 302 ++ paddle/cinn/lang/lower.h | 85 + paddle/cinn/lang/lower_impl.cc | 791 +++++ paddle/cinn/lang/lower_impl.h | 304 ++ paddle/cinn/lang/lower_impl_test.cc | 320 ++ paddle/cinn/lang/lower_test.cc | 155 + paddle/cinn/lang/packed_func.cc | 27 + paddle/cinn/lang/packed_func.h | 128 + paddle/cinn/lang/packed_func_test.cc | 95 + paddle/cinn/lang/placeholder.cc | 65 + paddle/cinn/lang/placeholder.h | 115 + paddle/cinn/lang/placeholder_test.cc | 48 + paddle/cinn/optim/CMakeLists.txt | 50 + paddle/cinn/optim/buffer_assign.cc | 156 + paddle/cinn/optim/buffer_assign.h | 39 + .../optim/cache_read_write_replace_test.cc | 125 + .../cinn/optim/call_arg_list_to_pod_value.cc | 108 + .../cinn/optim/call_arg_list_to_pod_value.h | 28 + paddle/cinn/optim/cast_bool_to_int8.cc | 47 + paddle/cinn/optim/cast_bool_to_int8.h | 34 + paddle/cinn/optim/cast_simplify.cc | 117 + paddle/cinn/optim/cast_simplify.h | 30 + paddle/cinn/optim/cast_simplify_test.cc | 63 + paddle/cinn/optim/collect_undefined_vars.cc | 109 + paddle/cinn/optim/collect_undefined_vars.h | 36 + paddle/cinn/optim/compute_inline_expand.cc | 233 ++ paddle/cinn/optim/compute_inline_expand.h | 33 + .../optim/eliminate_broadcast_in_forloop.cc | 111 + .../optim/eliminate_broadcast_in_forloop.h | 24 + paddle/cinn/optim/extern_call_process.cc | 41 + paddle/cinn/optim/extern_call_process.h | 27 + paddle/cinn/optim/fold_cinn_call_arguments.cc | 114 + paddle/cinn/optim/fold_cinn_call_arguments.h | 46 + paddle/cinn/optim/if_simplify.cc | 57 + paddle/cinn/optim/if_simplify.h | 22 + paddle/cinn/optim/if_simplify_test.cc | 70 + paddle/cinn/optim/insert_debug_log_callee.cc | 275 ++ paddle/cinn/optim/insert_debug_log_callee.h | 27 + paddle/cinn/optim/ir_copy.cc | 480 +++ paddle/cinn/optim/ir_copy.h | 43 + paddle/cinn/optim/ir_copy_test.cc | 31 + paddle/cinn/optim/ir_replace.cc | 64 + paddle/cinn/optim/ir_replace.h | 27 + paddle/cinn/optim/ir_simplify.cc | 365 ++ paddle/cinn/optim/ir_simplify.h | 37 + paddle/cinn/optim/ir_simplify_test.cc | 127 + .../optim/lower_function_call_bind_vars.cc | 73 + .../optim/lower_function_call_bind_vars.h | 26 + paddle/cinn/optim/lower_intrin.cc | 95 + paddle/cinn/optim/lower_intrin.h | 41 + paddle/cinn/optim/map_extern_call.cc | 119 + paddle/cinn/optim/map_extern_call.h | 33 + paddle/cinn/optim/optimize.cc | 111 + paddle/cinn/optim/optimize.h | 36 + paddle/cinn/optim/optimize_test.cc | 58 + paddle/cinn/optim/remove_nested_block.cc | 121 + paddle/cinn/optim/remove_nested_block.h | 33 + paddle/cinn/optim/remove_nested_block_test.cc | 58 + paddle/cinn/optim/remove_schedule_block.cc | 50 + paddle/cinn/optim/remove_schedule_block.h | 33 + .../cinn/optim/remove_schedule_block_test.cc | 98 + paddle/cinn/optim/replace_call_with_expr.cc | 125 + 
paddle/cinn/optim/replace_call_with_expr.h | 45 + .../cinn/optim/replace_call_with_expr_test.cc | 31 + .../optim/replace_const_param_to_integer.cc | 43 + .../optim/replace_const_param_to_integer.h | 34 + paddle/cinn/optim/replace_var_with_expr.cc | 159 + paddle/cinn/optim/replace_var_with_expr.h | 77 + paddle/cinn/optim/tensor_write_tell.cc | 19 + paddle/cinn/optim/tensor_write_tell.h | 54 + paddle/cinn/optim/transform_gpu_forloop.cc | 664 ++++ paddle/cinn/optim/transform_gpu_forloop.h | 65 + paddle/cinn/optim/transform_polyfor_to_for.cc | 136 + paddle/cinn/optim/transform_polyfor_to_for.h | 32 + .../optim/transform_polyfor_to_for_test.cc | 109 + paddle/cinn/optim/unroll_loops.cc | 118 + paddle/cinn/optim/unroll_loops.h | 24 + paddle/cinn/optim/unroll_loops_test.cc | 101 + paddle/cinn/optim/var_mod_simplify.cc | 91 + paddle/cinn/optim/var_mod_simplify.h | 32 + paddle/cinn/optim/vectorize_loops.cc | 890 +++++ paddle/cinn/optim/vectorize_loops.h | 37 + paddle/cinn/optim/vectorize_loops_test.cc | 288 ++ paddle/cinn/poly/CMakeLists.txt | 24 + paddle/cinn/poly/ast_gen.cc | 566 +++ paddle/cinn/poly/ast_gen.h | 99 + paddle/cinn/poly/ast_gen_test.cc | 130 + paddle/cinn/poly/compute_at_transform.cc | 244 ++ paddle/cinn/poly/compute_at_transform.h | 116 + paddle/cinn/poly/compute_at_transform_test.cc | 53 + paddle/cinn/poly/dim.cc | 36 + paddle/cinn/poly/dim.h | 68 + paddle/cinn/poly/domain.cc | 76 + paddle/cinn/poly/domain.h | 54 + .../cinn/poly/domain_add_unit_loop_mutator.cc | 217 ++ .../cinn/poly/domain_add_unit_loop_mutator.h | 51 + paddle/cinn/poly/graph.cc | 129 + paddle/cinn/poly/graph.h | 96 + paddle/cinn/poly/graph_test.cc | 24 + paddle/cinn/poly/isl_utils.cc | 512 +++ paddle/cinn/poly/isl_utils.h | 143 + paddle/cinn/poly/isl_utils_test.cc | 39 + paddle/cinn/poly/map.cc | 101 + paddle/cinn/poly/map.h | 108 + paddle/cinn/poly/naive_scheduler.cc | 53 + paddle/cinn/poly/naive_scheduler.h | 60 + paddle/cinn/poly/poly_scheduler.cc | 462 +++ paddle/cinn/poly/poly_scheduler.h | 87 + paddle/cinn/poly/poly_scheduler_test.cc | 21 + paddle/cinn/poly/schedule.cc | 254 ++ paddle/cinn/poly/schedule.h | 228 ++ paddle/cinn/poly/schedule_test.cc | 125 + paddle/cinn/poly/stage.cc | 1666 +++++++++ paddle/cinn/poly/stage.h | 537 +++ paddle/cinn/poly/stage_test.cc | 554 +++ paddle/cinn/pybind/CMakeLists.txt | 29 + paddle/cinn/pybind/backends.cc | 81 + paddle/cinn/pybind/bind.cc | 52 + paddle/cinn/pybind/bind.h | 52 + paddle/cinn/pybind/bind_utils.h | 168 + paddle/cinn/pybind/common.cc | 322 ++ paddle/cinn/pybind/framework.cc | 196 ++ paddle/cinn/pybind/frontend.cc | 799 +++++ paddle/cinn/pybind/ir.cc | 636 ++++ paddle/cinn/pybind/lang.cc | 248 ++ paddle/cinn/pybind/optim.cc | 52 + paddle/cinn/pybind/pe.cc | 135 + paddle/cinn/pybind/poly.cc | 124 + paddle/cinn/pybind/runtime.cc | 279 ++ paddle/cinn/pybind/utils.cc | 70 + paddle/cinn/runtime/CMakeLists.txt | 23 + paddle/cinn/runtime/buffer.cc | 52 + paddle/cinn/runtime/buffer.h | 100 + paddle/cinn/runtime/cinn_runtime.cc | 495 +++ paddle/cinn/runtime/cinn_runtime.h | 570 ++++ paddle/cinn/runtime/cinn_runtime_test.cc | 49 + paddle/cinn/runtime/cinn_x86_device_impl.cc | 85 + paddle/cinn/runtime/cpu/CMakeLists.txt | 26 + paddle/cinn/runtime/cpu/cblas.cc | 226 ++ paddle/cinn/runtime/cpu/cblas.h | 102 + paddle/cinn/runtime/cpu/host_intrinsics.cc | 460 +++ paddle/cinn/runtime/cpu/host_intrinsics.h | 122 + .../cinn/runtime/cpu/host_intrinsics_test.cc | 208 ++ paddle/cinn/runtime/cpu/mkl_math.cc | 105 + paddle/cinn/runtime/cpu/mkl_math.h | 53 + 
paddle/cinn/runtime/cpu/mkl_math_test.cc | 217 ++ paddle/cinn/runtime/cpu/mkldnn_math.cc | 204 ++ paddle/cinn/runtime/cpu/mkldnn_math.h | 45 + paddle/cinn/runtime/cpu/mkldnn_math_test.cc | 123 + paddle/cinn/runtime/cpu/thread_backend.cc | 69 + paddle/cinn/runtime/cpu/thread_backend.h | 46 + paddle/cinn/runtime/cpu/use_extern_funcs.h | 27 + paddle/cinn/runtime/cuda/CMakeLists.txt | 19 + paddle/cinn/runtime/cuda/bfloat16.h | 402 +++ .../runtime/cuda/cinn_cuda_runtime_source.cuh | 865 +++++ paddle/cinn/runtime/cuda/cublas_util.h | 328 ++ .../runtime/cuda/cuda_instrinsics_bfloat16.cc | 80 + .../runtime/cuda/cuda_instrinsics_float16.cc | 124 + paddle/cinn/runtime/cuda/cuda_intrinsics.cc | 733 ++++ .../runtime/cuda/cuda_intrinsics_reduce.cc | 156 + paddle/cinn/runtime/cuda/cuda_module.cc | 151 + paddle/cinn/runtime/cuda/cuda_module.h | 80 + paddle/cinn/runtime/cuda/cuda_module_test.cc | 179 + paddle/cinn/runtime/cuda/cuda_util.cc | 2277 +++++++++++++ paddle/cinn/runtime/cuda/cuda_util.h | 309 ++ paddle/cinn/runtime/cuda/float16.h | 629 ++++ paddle/cinn/runtime/cuda/test_util.h | 56 + paddle/cinn/runtime/cuda/use_extern_funcs.h | 24 + paddle/cinn/runtime/custom_function.cc | 199 ++ paddle/cinn/runtime/custom_function.h | 66 + paddle/cinn/runtime/custom_function_test.cc | 352 ++ paddle/cinn/runtime/flags.cc | 227 ++ paddle/cinn/runtime/flags.h | 62 + paddle/cinn/runtime/intrinsic.cc | 67 + paddle/cinn/runtime/intrinsic.h | 136 + paddle/cinn/runtime/intrinsic_types.cc | 29 + paddle/cinn/runtime/intrinsic_types.h | 52 + paddle/cinn/runtime/tiny_runtime.cc | 158 + paddle/cinn/runtime/use_extern_funcs.h | 20 + paddle/cinn/utils/CMakeLists.txt | 23 + paddle/cinn/utils/data_util.cc | 121 + paddle/cinn/utils/data_util.h | 45 + paddle/cinn/utils/dot_lang.cc | 159 + paddle/cinn/utils/dot_lang.h | 135 + paddle/cinn/utils/error.cc | 17 + paddle/cinn/utils/error.h | 24 + paddle/cinn/utils/event.cc | 125 + paddle/cinn/utils/event.h | 113 + paddle/cinn/utils/functional.cc | 40 + paddle/cinn/utils/functional.h | 127 + paddle/cinn/utils/functional_test.cc | 119 + paddle/cinn/utils/multi_threading.cc | 96 + paddle/cinn/utils/multi_threading.h | 62 + paddle/cinn/utils/multi_threading_test.cc | 59 + paddle/cinn/utils/profiler.cc | 126 + paddle/cinn/utils/profiler.h | 88 + paddle/cinn/utils/profiler_test.cc | 78 + paddle/cinn/utils/random_engine.cc | 41 + paddle/cinn/utils/random_engine.h | 109 + paddle/cinn/utils/registry.h | 210 ++ paddle/cinn/utils/sized_multi_set.cc | 17 + paddle/cinn/utils/sized_multi_set.h | 82 + paddle/cinn/utils/sized_multi_set_test.cc | 82 + paddle/cinn/utils/small_vector.cc | 17 + paddle/cinn/utils/small_vector.h | 23 + paddle/cinn/utils/string.cc | 167 + paddle/cinn/utils/string.h | 94 + paddle/cinn/utils/string_test.cc | 38 + paddle/cinn/utils/timer.cc | 31 + paddle/cinn/utils/timer.h | 35 + paddle/cinn/utils/type_defs.h | 46 + python/cinn/__init__.py | 28 + python/cinn/auto_schedule/__init__.py | 13 + .../cinn/auto_schedule/cost_model/__init__.py | 23 + .../auto_schedule/cost_model/cost_model.py | 82 + .../cost_model/xgb_cost_model.py | 98 + python/cinn/backends.py | 16 + python/cinn/common.py | 20 + python/cinn/framework.py | 15 + python/cinn/frontend.py | 15 + python/cinn/ir/__init__.py | 42 + python/cinn/lang.py | 18 + python/cinn/libs/__init__.py | 15 + python/cinn/optim.py | 16 + python/cinn/pe.py | 15 + python/cinn/poly.py | 15 + python/cinn/runtime.py | 15 + python/cinn/utils.py | 15 + python/cinn/version/__init__.py | 18 + .../cost_model/test_cost_model.py | 58 + 
test/cinn/conv2d_utils.py | 95 + test/cinn/fake_model/naive_mul.py | 43 + test/cinn/fake_model/naive_multi_fc.py | 58 + test/cinn/fake_model/resnet_model.py | 51 + test/cinn/fusion/fusion_test.py | 62 + .../fusion/test_cast_broadcast_reduce_max.py | 68 + test/cinn/fusion/test_reduce_cast.py | 39 + test/cinn/fusion/test_select_reduce.py | 51 + test/cinn/op_mappers/op_mapper_test.py | 423 +++ test/cinn/op_mappers/test_argmax_op.py | 113 + test/cinn/op_mappers/test_argmin_op.py | 113 + test/cinn/op_mappers/test_argsort_op.py | 55 + test/cinn/op_mappers/test_assign_value_op.py | 117 + test/cinn/op_mappers/test_atan2_op.py | 103 + test/cinn/op_mappers/test_batch_norm_op.py | 112 + test/cinn/op_mappers/test_bitwise_op.py | 85 + test/cinn/op_mappers/test_cholesky_op.py | 92 + test/cinn/op_mappers/test_clip_op.py | 256 ++ test/cinn/op_mappers/test_compare_op.py | 90 + test/cinn/op_mappers/test_conv2d_op.py | 93 + test/cinn/op_mappers/test_cumsum_op.py | 112 + test/cinn/op_mappers/test_elementwise_op.py | 117 + test/cinn/op_mappers/test_expand_op.py | 50 + test/cinn/op_mappers/test_expand_v2_op.py | 50 + test/cinn/op_mappers/test_fill_constant_op.py | 89 + test/cinn/op_mappers/test_flip_op.py | 52 + test/cinn/op_mappers/test_gather_nd_op.py | 63 + test/cinn/op_mappers/test_gather_op.py | 65 + .../op_mappers/test_gaussian_random_op.py | 79 + test/cinn/op_mappers/test_layer_norm_op.py | 79 + test/cinn/op_mappers/test_log1p_op.py | 50 + test/cinn/op_mappers/test_logical_op.py | 82 + test/cinn/op_mappers/test_lookup_table_op.py | 106 + test/cinn/op_mappers/test_mul_op.py | 80 + test/cinn/op_mappers/test_norm_op.py | 70 + test/cinn/op_mappers/test_one_hot_op.py | 116 + test/cinn/op_mappers/test_pool2d_op.py | 131 + test/cinn/op_mappers/test_pow_op.py | 83 + test/cinn/op_mappers/test_randint_op.py | 81 + test/cinn/op_mappers/test_reduce_op.py | 141 + test/cinn/op_mappers/test_reverse_op.py | 52 + test/cinn/op_mappers/test_roll_op.py | 111 + test/cinn/op_mappers/test_scale_op.py | 115 + test/cinn/op_mappers/test_scatter_op.py | 87 + test/cinn/op_mappers/test_sign_op.py | 50 + test/cinn/op_mappers/test_split_op.py | 94 + test/cinn/op_mappers/test_squeeze_op.py | 62 + test/cinn/op_mappers/test_stack_op.py | 54 + test/cinn/op_mappers/test_strided_slice_op.py | 121 + .../op_mappers/test_take_along_axis_op.py | 150 + test/cinn/op_mappers/test_tile_op.py | 91 + test/cinn/op_mappers/test_transpose2_op.py | 71 + .../op_mappers/test_triangular_solve_op.py | 69 + test/cinn/op_mappers/test_unary_op.py | 242 ++ .../cinn/op_mappers/test_uniform_random_op.py | 91 + test/cinn/op_mappers/test_where_op.py | 60 + test/cinn/ops/op_test.py | 320 ++ test/cinn/ops/op_test_helper.py | 134 + test/cinn/ops/test_abs_op.py | 111 + test/cinn/ops/test_acos_op.py | 104 + test/cinn/ops/test_add_op.py | 257 ++ test/cinn/ops/test_arange_op.py | 191 ++ test/cinn/ops/test_argsort_op.py | 112 + test/cinn/ops/test_asin_op.py | 109 + test/cinn/ops/test_asinh_op.py | 101 + test/cinn/ops/test_atan2_op.py | 139 + test/cinn/ops/test_atan_op.py | 101 + test/cinn/ops/test_atanh_op.py | 101 + test/cinn/ops/test_batch_norm_op.py | 245 ++ test/cinn/ops/test_binary_elementwise_op.py | 370 ++ test/cinn/ops/test_bitcast_convert_op.py | 103 + test/cinn/ops/test_bitwise_op.py | 256 ++ test/cinn/ops/test_broadcast_to_op.py | 174 + test/cinn/ops/test_broadcast_to_op_new.py | 224 ++ test/cinn/ops/test_cast_op.py | 155 + test/cinn/ops/test_cbrt_op.py | 143 + test/cinn/ops/test_ceil_op.py | 144 + test/cinn/ops/test_cholesky_op.py | 231 ++ 
 test/cinn/ops/test_clz_op.py | 140 +
 test/cinn/ops/test_comparison_op.py | 357 ++
 test/cinn/ops/test_concat_op.py | 362 ++
 test/cinn/ops/test_constant_op.py | 163 +
 test/cinn/ops/test_conv2d_op.py | 210 ++
 test/cinn/ops/test_cos_op.py | 106 +
 test/cinn/ops/test_cosh_op.py | 106 +
 test/cinn/ops/test_depthwise_conv2d_op.py | 192 ++
 test/cinn/ops/test_divide_op.py | 278 ++
 test/cinn/ops/test_dropout_infer_op.py | 121 +
 test/cinn/ops/test_erf_op.py | 106 +
 test/cinn/ops/test_exp_op.py | 104 +
 test/cinn/ops/test_expand_dims.py | 145 +
 test/cinn/ops/test_fill_constant_op.py | 259 ++
 test/cinn/ops/test_floor_divide_op.py | 232 ++
 test/cinn/ops/test_floor_op.py | 109 +
 test/cinn/ops/test_gather_nd_op.py | 99 +
 test/cinn/ops/test_gather_op.py | 157 +
 test/cinn/ops/test_gaussian_random_op.py | 158 +
 test/cinn/ops/test_gelu_op.py | 116 +
 test/cinn/ops/test_identity_op.py | 114 +
 test/cinn/ops/test_is_finite_op.py | 116 +
 test/cinn/ops/test_is_inf_op.py | 116 +
 test/cinn/ops/test_is_nan_op.py | 116 +
 test/cinn/ops/test_isclose_op.py | 205 ++
 test/cinn/ops/test_left_shift_op.py | 150 +
 test/cinn/ops/test_log_op.py | 145 +
 test/cinn/ops/test_logical_right_shift_op.py | 130 +
 test/cinn/ops/test_lookup_table_op.py | 112 +
 test/cinn/ops/test_matmul_op.py | 259 ++
 test/cinn/ops/test_max_op.py | 108 +
 test/cinn/ops/test_mod_op.py | 132 +
 test/cinn/ops/test_mul_op.py | 64 +
 test/cinn/ops/test_multiply_op.py | 89 +
 test/cinn/ops/test_negative_op.py | 121 +
 test/cinn/ops/test_one_hot_op.py | 71 +
 test/cinn/ops/test_pool2d_op.py | 339 ++
 test/cinn/ops/test_popc_op.py | 139 +
 test/cinn/ops/test_pow_op.py | 152 +
 test/cinn/ops/test_randint_op.py | 82 +
 test/cinn/ops/test_reciprocal_op.py | 98 +
 test/cinn/ops/test_reduce_op.py | 680 ++++
 test/cinn/ops/test_reduce_op_new.py | 216 ++
 test/cinn/ops/test_reduce_op_other.py | 87 +
 test/cinn/ops/test_relu6_op.py | 94 +
 test/cinn/ops/test_relu_op.py | 155 +
 test/cinn/ops/test_remainder_op.py | 198 ++
 test/cinn/ops/test_repeat_op.py | 267 ++
 test/cinn/ops/test_reshape_op.py | 223 ++
 test/cinn/ops/test_resize_op.py | 115 +
 test/cinn/ops/test_reverse_op.py | 311 ++
 test/cinn/ops/test_right_shift_op.py | 150 +
 test/cinn/ops/test_round_op.py | 112 +
 test/cinn/ops/test_rsqrt_op.py | 109 +
 test/cinn/ops/test_scale_op.py | 169 +
 test/cinn/ops/test_scatter_add.py | 351 ++
 test/cinn/ops/test_scatter_assign_op.py | 263 ++
 test/cinn/ops/test_select_op.py | 153 +
 test/cinn/ops/test_sigmoid_op.py | 144 +
 test/cinn/ops/test_sign_op.py | 146 +
 test/cinn/ops/test_sin_op.py | 106 +
 test/cinn/ops/test_sinh_op.py | 106 +
 test/cinn/ops/test_slice_assign_op.py | 421 +++
 test/cinn/ops/test_slice_op.py | 378 +++
 test/cinn/ops/test_softmax_op.py | 97 +
 test/cinn/ops/test_sort_op.py | 228 ++
 test/cinn/ops/test_split_op.py | 368 ++
 test/cinn/ops/test_sqrt_op.py | 107 +
 test/cinn/ops/test_squeeze_op.py | 213 ++
 test/cinn/ops/test_subtract_op.py | 250 ++
 test/cinn/ops/test_sum_op.py | 163 +
 test/cinn/ops/test_tan_op.py | 106 +
 test/cinn/ops/test_tanh_op.py | 106 +
 test/cinn/ops/test_top_k_op.py | 307 ++
 test/cinn/ops/test_transpose_op.py | 272 ++
 test/cinn/ops/test_triangular_solve_op.py | 392 +++
 test/cinn/ops/test_trunc_op.py | 111 +
 test/cinn/ops/test_unary_elementwise_op.py | 409 +++
 test/cinn/ops/test_uniform_random_op.py | 159 +
 test/cinn/ops/test_zero_dim_tensor.py | 642 ++++
 test/cinn/passes/pass_test.py | 104 +
 test/cinn/passes/test_auto_cast_pass.py | 40 +
 test/cinn/passes/test_expand_zero_dim_pass.py | 44 +
 .../test_transpose_floding_input_pass.py | 244 ++
 .../test_transpose_floding_output_pass.py | 108 +
 test/cinn/pool_utils.py | 422 +++
 test/cinn/test_common.py | 42 +
 test/cinn/test_computation.py | 130 +
 test/cinn/test_efficientnet.py | 108 +
 test/cinn/test_facedet.py | 109 +
 test/cinn/test_frontend.py | 191 ++
 test/cinn/test_hlir_framework.py | 34 +
 test/cinn/test_ir.py | 50 +
 test/cinn/test_matmul.py | 133 +
 test/cinn/test_mobilenetv1.py | 109 +
 test/cinn/test_mobilenetv2.py | 112 +
 test/cinn/test_netbuilder.py | 118 +
 test/cinn/test_op_benchmark.py | 479 +++
 test/cinn/test_op_broadcast.py | 104 +
 test/cinn/test_op_nn.py | 595 ++++
 test/cinn/test_op_transform.py | 212 ++
 test/cinn/test_packed_func.py | 75 +
 test/cinn/test_paddle_model_convertor.py | 269 ++
 test/cinn/test_pe_elementwise.py | 164 +
 test/cinn/test_pe_reduction.py | 179 +
 test/cinn/test_pe_transform.py | 136 +
 test/cinn/test_resnet.py | 89 +
 test/cinn/test_resnet18.py | 110 +
 test/cinn/test_resnet50.py | 113 +
 test/cinn/test_squeezenet.py | 108 +
 test/cinn/test_utils.py | 152 +
 test/cpp/cinn/CMakeLists.txt | 22 +
 test/cpp/cinn/benchmark/CMakeLists.txt | 11 +
 .../cinn/benchmark/test_all_ops_default.cc | 368 ++
 test/cpp/cinn/benchmark/test_elementwise.cc | 51 +
 test/cpp/cinn/benchmark/test_elementwise.h | 54 +
 test/cpp/cinn/benchmark/test_matmul.cc | 305 ++
 test/cpp/cinn/benchmark/test_matmul.h | 122 +
 test/cpp/cinn/benchmark/test_utils.cc | 232 ++
 test/cpp/cinn/benchmark/test_utils.h | 96 +
 test/cpp/cinn/concrete_program_builder.h | 100 +
 test/cpp/cinn/program_builder.cc | 63 +
 test/cpp/cinn/program_builder.h | 97 +
 test/cpp/cinn/test01_elementwise_add_case.cc | 162 +
 test/cpp/cinn/test01_elementwise_add_main.cc | 145 +
 test/cpp/cinn/test02_helper.h | 308 ++
 test/cpp/cinn/test02_matmul_case.cc | 222 ++
 test/cpp/cinn/test02_matmul_main.cc | 333 ++
 test/cpp/cinn/test03_convolution_case.cc | 27 +
 test/cpp/cinn/test03_convolution_main.cc | 70 +
 1056 files changed, 188306 insertions(+), 1 deletion(-)
 create mode 100644 paddle/cinn/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/analysis/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/analysis/analyze_ir.cc
 create mode 100644 paddle/cinn/auto_schedule/analysis/analyze_ir.h
 create mode 100644 paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc
 create mode 100644 paddle/cinn/auto_schedule/auto_schedule.proto
 create mode 100644 paddle/cinn/auto_schedule/auto_tuner.cc
 create mode 100644 paddle/cinn/auto_schedule/auto_tuner.h
 create mode 100644 paddle/cinn/auto_schedule/auto_tuner_test.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature.h
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature_extractor.h
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/feature_test.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
 create mode 100644 paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
 create mode 100644 paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
 create mode 100644 paddle/cinn/auto_schedule/database/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/database/database.cc
 create mode 100644 paddle/cinn/auto_schedule/database/database.h
 create mode 100644 paddle/cinn/auto_schedule/database/database_test.cc
 create mode 100644 paddle/cinn/auto_schedule/database/jsonfile_database.cc
 create mode 100644 paddle/cinn/auto_schedule/database/jsonfile_database.h
 create mode 100644 paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
 create mode 100644 paddle/cinn/auto_schedule/measure/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/measure/measure.h
 create mode 100644 paddle/cinn/auto_schedule/measure/measurer_test.cc
 create mode 100644 paddle/cinn/auto_schedule/measure/schedule_measurer.cc
 create mode 100644 paddle/cinn/auto_schedule/measure/schedule_measurer.h
 create mode 100644 paddle/cinn/auto_schedule/measure/simple_builder.cc
 create mode 100644 paddle/cinn/auto_schedule/measure/simple_builder.h
 create mode 100644 paddle/cinn/auto_schedule/measure/simple_runner.cc
 create mode 100644 paddle/cinn/auto_schedule/measure/simple_runner.h
 create mode 100644 paddle/cinn/auto_schedule/measure/simple_runner_test.cc
 create mode 100644 paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc
 create mode 100644 paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h
 create mode 100644 paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc
 create mode 100644 paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/block_sampler.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/block_sampler.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/rule_sampler.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/rule_sampler.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_space.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_space.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_space_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_state.cc
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_state.h
 create mode 100644 paddle/cinn/auto_schedule/search_space/search_state_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h
 create mode 100644 paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc
 create mode 100644 paddle/cinn/auto_schedule/task/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/task/task_creator.cc
 create mode 100644 paddle/cinn/auto_schedule/task/task_creator.h
 create mode 100644 paddle/cinn/auto_schedule/task/task_creator_test.cc
 create mode 100644 paddle/cinn/auto_schedule/task/task_optimizer.cc
 create mode 100644 paddle/cinn/auto_schedule/task/task_optimizer.h
 create mode 100644 paddle/cinn/auto_schedule/task/task_registry.h
 create mode 100644 paddle/cinn/auto_schedule/task/task_registry_test.cc
 create mode 100644 paddle/cinn/auto_schedule/task/tune_task.cc
 create mode 100644 paddle/cinn/auto_schedule/task/tune_task.h
 create mode 100755 paddle/cinn/auto_schedule/task/tune_task_test.cc
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/round_robin.cc
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/round_robin.h
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h
 create mode 100644 paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc
 create mode 100644 paddle/cinn/auto_schedule/tests/CMakeLists.txt
 create mode 100644 paddle/cinn/auto_schedule/tests/performance_comparison_test.cc
 create mode 100644 paddle/cinn/auto_schedule/tuning.h
 create mode 100755 paddle/cinn/backends/CMakeLists.txt
 create mode 100644 paddle/cinn/backends/_x86_builtin_source.cc
 create mode 100644 paddle/cinn/backends/codegen_c.cc
 create mode 100755 paddle/cinn/backends/codegen_c.h
 create mode 100755 paddle/cinn/backends/codegen_c_test.cc
 create mode 100644 paddle/cinn/backends/codegen_c_x86.cc
 create mode 100644 paddle/cinn/backends/codegen_c_x86.h
 create mode 100644 paddle/cinn/backends/codegen_c_x86_test.cc
 create mode 100644 paddle/cinn/backends/codegen_cuda_dev.cc
 create mode 100644 paddle/cinn/backends/codegen_cuda_dev.h
 create mode 100644 paddle/cinn/backends/codegen_cuda_generate_test.cc
 create mode 100644 paddle/cinn/backends/codegen_cuda_host.cc
 create mode 100644 paddle/cinn/backends/codegen_cuda_host.h
 create mode 100644 paddle/cinn/backends/codegen_cuda_util.cc
 create mode 100755 paddle/cinn/backends/codegen_cuda_util.h
 create mode 100644 paddle/cinn/backends/codegen_debug_test.cc
 create mode 100644 paddle/cinn/backends/compiler.cc
 create mode 100644 paddle/cinn/backends/compiler.h
 create mode 100644 paddle/cinn/backends/compiler_test.cc
 create mode 100644 paddle/cinn/backends/cuda_util.cc
 create mode 100644 paddle/cinn/backends/cuda_util.h
 create mode 100644 paddle/cinn/backends/extern_func_emitter.cc
 create mode 100644 paddle/cinn/backends/extern_func_emitter.h
 create mode 100644 paddle/cinn/backends/extern_func_emitter_builtin.cc
 create mode 100644 paddle/cinn/backends/extern_func_emitter_builtin.h
 create mode 100644 paddle/cinn/backends/extern_func_jit_register.cc
 create mode 100644 paddle/cinn/backends/extern_func_jit_register.h
 create mode 100644 paddle/cinn/backends/extern_func_protos.cc
 create mode 100644 paddle/cinn/backends/extern_func_protos.h
 create mode 100644 paddle/cinn/backends/function_prototype.cc
 create mode 100644 paddle/cinn/backends/function_prototype.h
 create mode 100644 paddle/cinn/backends/generated1.cu
 create mode 100644 paddle/cinn/backends/generated_module1.cc
 create mode 100644 paddle/cinn/backends/ir_schedule_test.cc
 create mode 100755 paddle/cinn/backends/llvm/CMakeLists.txt
 create mode 100644 paddle/cinn/backends/llvm/codegen_llvm.cc
 create mode 100644 paddle/cinn/backends/llvm/codegen_llvm.h
 create mode 100644 paddle/cinn/backends/llvm/codegen_llvm_test.cc
 create mode 100644 paddle/cinn/backends/llvm/codegen_x86.cc
 create mode 100644 paddle/cinn/backends/llvm/codegen_x86.h
 create mode 100644 paddle/cinn/backends/llvm/codegen_x86_test.cc
 create mode 100644 paddle/cinn/backends/llvm/execution_engine.cc
 create mode 100644 paddle/cinn/backends/llvm/execution_engine.h
 create mode 100644 paddle/cinn/backends/llvm/execution_engine_test.cc
 create mode 100644 paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py
 create mode 100644 paddle/cinn/backends/llvm/ir_builder_mixin.h
 create mode 100644 paddle/cinn/backends/llvm/llvm_intrin_rule.h
 create mode 100644 paddle/cinn/backends/llvm/llvm_optimizer.cc
 create mode 100644 paddle/cinn/backends/llvm/llvm_optimizer.h
 create mode 100644 paddle/cinn/backends/llvm/llvm_util.cc
 create mode 100644 paddle/cinn/backends/llvm/llvm_util.h
 create mode 100644 paddle/cinn/backends/llvm/runtime_symbol_registry.cc
 create mode 100644 paddle/cinn/backends/llvm/runtime_symbol_registry.h
 create mode 100755 paddle/cinn/backends/llvm/simple_jit.cc
 create mode 100755 paddle/cinn/backends/llvm/simple_jit.h
 create mode 100644 paddle/cinn/backends/modular.cc
 create mode 100644 paddle/cinn/backends/modular.h
 create mode 100644 paddle/cinn/backends/nvrtc/CMakeLists.txt
 create mode 100644 paddle/cinn/backends/nvrtc/header_generator.cc
 create mode 100644 paddle/cinn/backends/nvrtc/header_generator.h
 create mode 100644 paddle/cinn/backends/nvrtc/nvrtc_util.cc
 create mode 100644 paddle/cinn/backends/nvrtc/nvrtc_util.h
 create mode 100644 paddle/cinn/backends/nvrtc/nvrtc_util_test.cc
 create mode 100644 paddle/cinn/backends/outputs.cc
 create mode 100644 paddle/cinn/backends/outputs.h
 create mode 100644 paddle/cinn/backends/raw_cuda_code_test.cu
 create mode 100644 paddle/cinn/cinn.h
 create mode 100644 paddle/cinn/common/CMakeLists.txt
 create mode 100644 paddle/cinn/common/arithmatic.cc
 create mode 100644 paddle/cinn/common/arithmatic.h
 create mode 100644 paddle/cinn/common/arithmatic_test.cc
 create mode 100644 paddle/cinn/common/axis.cc
 create mode 100644 paddle/cinn/common/axis.h
 create mode 100644 paddle/cinn/common/axis_test.cc
 create mode 100644 paddle/cinn/common/bfloat16.h
 create mode 100644 paddle/cinn/common/cas.cc
 create mode 100755 paddle/cinn/common/cas.h
 create mode 100644 paddle/cinn/common/cas_test.cc
 create mode 100644 paddle/cinn/common/cinn_value.cc
 create mode 100755 paddle/cinn/common/cinn_value.h
 create mode 100644 paddle/cinn/common/cinn_value_test.cc
 create mode 100644 paddle/cinn/common/common.h
 create mode 100644 paddle/cinn/common/context.cc
 create mode 100644 paddle/cinn/common/context.h
 create mode 100644 paddle/cinn/common/cost_model.h
 create mode 100644 paddle/cinn/common/cuda_test_helper.cc
 create mode 100644 paddle/cinn/common/cuda_test_helper.h
 create mode 100644 paddle/cinn/common/debug_manager.cc
 create mode 100644 paddle/cinn/common/debug_manager.h
 create mode 100644 paddle/cinn/common/float16.h
 create mode 100644 paddle/cinn/common/float16_bfloat16_cuda_test.cu
 create mode 100644 paddle/cinn/common/float16_bfloat16_host_test.cc
 create mode 100644 paddle/cinn/common/float16_bfloat16_utils.h
 create mode 100755 paddle/cinn/common/graph_utils.cc
 create mode 100644 paddle/cinn/common/graph_utils.h
 create mode 100644 paddle/cinn/common/graph_utils_test.cc
 create mode 100644 paddle/cinn/common/info_registry.cc
 create mode 100644 paddle/cinn/common/info_registry.h
 create mode 100755 paddle/cinn/common/ir_util.cc
 create mode 100644 paddle/cinn/common/ir_util.h
 create mode 100644 paddle/cinn/common/macros.h
 create mode 100644 paddle/cinn/common/object.cc
 create mode 100644 paddle/cinn/common/object.h
 create mode 100644 paddle/cinn/common/python_interpreter_guard.cc
 create mode 100644 paddle/cinn/common/python_interpreter_guard.h
 create mode 100644 paddle/cinn/common/shared.cc
 create mode 100644 paddle/cinn/common/shared.h
 create mode 100644 paddle/cinn/common/shared_test.cc
 create mode 100644 paddle/cinn/common/target.cc
 create mode 100755 paddle/cinn/common/target.h
 create mode 100644 paddle/cinn/common/test_helper.cc
 create mode 100644 paddle/cinn/common/test_helper.h
 create mode 100644 paddle/cinn/common/type.cc
 create mode 100644 paddle/cinn/common/type.h
 create mode 100644 paddle/cinn/common/type_test.cc
 create mode 100644 paddle/cinn/common/union_find.cc
 create mode 100644 paddle/cinn/common/union_find.h
 create mode 100755 paddle/cinn/frontend/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/computation.cc
 create mode 100644 paddle/cinn/frontend/computation.h
 create mode 100644 paddle/cinn/frontend/computation_test.cc
 create mode 100755 paddle/cinn/frontend/decomposer/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/decomposer/activation.cc
 create mode 100644 paddle/cinn/frontend/decomposer/activation_test.cc
 create mode 100644 paddle/cinn/frontend/decomposer/batch_norm.cc
 create mode 100755 paddle/cinn/frontend/decomposer/batch_norm_test.cc
 create mode 100644 paddle/cinn/frontend/decomposer/broadcast.cc
 create mode 100644 paddle/cinn/frontend/decomposer/broadcast_test.cc
 create mode 100644 paddle/cinn/frontend/decomposer/elementwise.cc
 create mode 100644 paddle/cinn/frontend/decomposer/elementwise_test.cc
 create mode 100644 paddle/cinn/frontend/decomposer/test_helper.cc
 create mode 100644 paddle/cinn/frontend/decomposer/test_helper.h
 create mode 100644 paddle/cinn/frontend/decomposer/top_k.cc
 create mode 100644 paddle/cinn/frontend/decomposer/top_k_test.cc
 create mode 100644 paddle/cinn/frontend/decomposer/use_decomposer.h
 create mode 100644 paddle/cinn/frontend/decomposer_registry.h
 create mode 100644 paddle/cinn/frontend/decomposer_registry_test.cc
 create mode 100755 paddle/cinn/frontend/interpreter.cc
 create mode 100755 paddle/cinn/frontend/interpreter.h
 create mode 100755 paddle/cinn/frontend/interpreter_test.cc
 create mode 100644 paddle/cinn/frontend/net_builder.cc
 create mode 100644 paddle/cinn/frontend/net_builder.h
 create mode 100644 paddle/cinn/frontend/net_builder_test.cc
 create mode 100644 paddle/cinn/frontend/op_mapper_registry.cc
 create mode 100644 paddle/cinn/frontend/op_mapper_registry.h
 create mode 100644 paddle/cinn/frontend/op_mapper_registry_test.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/op_mappers/common_utils.h
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/argsort.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/atan.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/binary.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/cholesky.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/clip.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/compare.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/concat.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/constant.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/conv2d.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/cumsum.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/dropout.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/elementwise.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/expand.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/fetch_feed.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/flip.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/gather.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/gather_nd.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/gaussian_random.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/layer_norm.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/log.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/lookup_table.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/matmul.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/mul.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/norm.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/one_hot.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/pool2d.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/randint.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/reduce.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/relu.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/reshape.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/reverse.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/roll.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/scale.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/scatter.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/slice.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/softmax.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/squeeze.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/strided_slice.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/take_along_axis.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/tile.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/top_k.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/transpose.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/triangular_solve.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/unary.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/uniform_random.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/unsqueeze.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/paddle/where.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/science/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/op_mappers/science/broadcast.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/science/compare.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/science/math.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/science/transform.cc
 create mode 100644 paddle/cinn/frontend/op_mappers/use_op_mappers.h
 create mode 100644 paddle/cinn/frontend/optimize.cc
 create mode 100755 paddle/cinn/frontend/optimize.h
 create mode 100644 paddle/cinn/frontend/paddle/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/paddle/README.md
 create mode 100644 paddle/cinn/frontend/paddle/compatible_pb.cc
 create mode 100644 paddle/cinn/frontend/paddle/compatible_pb.h
 create mode 100644 paddle/cinn/frontend/paddle/cpp/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/paddle/cpp/block_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/cpp/block_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/cpp/desc_api.h
 create mode 100644 paddle/cinn/frontend/paddle/cpp/op_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/cpp/op_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/cpp/program_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/cpp/program_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/cpp/var_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/cpp/var_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/framework.proto
 create mode 100755 paddle/cinn/frontend/paddle/model_parser.cc
 create mode 100644 paddle/cinn/frontend/paddle/model_parser.h
 create mode 100644 paddle/cinn/frontend/paddle/model_parser_test.cc
 create mode 100644 paddle/cinn/frontend/paddle/pb/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/paddle/pb/block_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/pb/block_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/pb/op_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/pb/op_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/pb/program_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/pb/program_desc.h
 create mode 100644 paddle/cinn/frontend/paddle/pb/var_desc.cc
 create mode 100644 paddle/cinn/frontend/paddle/pb/var_desc.h
 create mode 100644 paddle/cinn/frontend/paddle_model_convertor.cc
 create mode 100644 paddle/cinn/frontend/paddle_model_convertor.h
 create mode 100644 paddle/cinn/frontend/paddle_model_convertor_test.cc
 create mode 100644 paddle/cinn/frontend/paddle_model_to_program.cc
 create mode 100644 paddle/cinn/frontend/paddle_model_to_program.h
 create mode 100755 paddle/cinn/frontend/pass/CMakeLists.txt
 create mode 100644 paddle/cinn/frontend/pass/auto_broadcast.cc
 create mode 100644 paddle/cinn/frontend/pass/auto_cast.cc
 create mode 100644 paddle/cinn/frontend/pass/auto_cast_test.cc
 create mode 100644 paddle/cinn/frontend/pass/cast_collapsing.cc
 create mode 100644 paddle/cinn/frontend/pass/cast_collapsing_test.cc
 create mode 100644 paddle/cinn/frontend/pass/dead_code_eliminate.cc
 create mode 100644 paddle/cinn/frontend/pass/dead_code_eliminate_test.cc
 create mode 100755 paddle/cinn/frontend/pass/decomposer.cc
 create mode 100644 paddle/cinn/frontend/pass/decomposer_test.cc
 create mode 100644 paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
 create mode 100644 paddle/cinn/frontend/pass/expand_zero_dim_pass_test.cc
 create mode 100644 paddle/cinn/frontend/pass/fill_constant_folding.cc
 create mode 100644 paddle/cinn/frontend/pass/fill_constant_folding_test.cc
 create mode 100644 paddle/cinn/frontend/pass/fill_constant_rewriter.cc
 create mode 100644 paddle/cinn/frontend/pass/fill_constant_rewriter_test.cc
 create mode 100644 paddle/cinn/frontend/pass/gemm_rewriter.cc
 create mode 100755 paddle/cinn/frontend/pass/gemm_rewriter_test.cc
 create mode 100644 paddle/cinn/frontend/pass/pass_test_helper.h
 create mode 100644 paddle/cinn/frontend/pass/program_topoerror_test.cc
 create mode 100644 paddle/cinn/frontend/pass/remove_identity.cc
 create mode 100644 paddle/cinn/frontend/pass/remove_identity_test.cc
 create mode 100644 paddle/cinn/frontend/pass/test_helper.h
 create mode 100644 paddle/cinn/frontend/pass/transpose_collapsing.cc
 create mode 100644 paddle/cinn/frontend/pass/transpose_collapsing_test.cc
 create mode 100644 paddle/cinn/frontend/pass/transpose_folding_base.h
 create mode 100644 paddle/cinn/frontend/pass/transpose_folding_input.cc
 create mode 100644 paddle/cinn/frontend/pass/transpose_folding_input_test.cc
 create mode 100644 paddle/cinn/frontend/pass/transpose_folding_output.cc
 create mode 100755 paddle/cinn/frontend/pass/transpose_folding_output_test.cc
 create mode 100644 paddle/cinn/frontend/pass/transpose_scale_folding_test.cc
 create mode 100644 paddle/cinn/frontend/pass/use_program_pass.h
 create mode 100644 paddle/cinn/frontend/program_pass.cc
 create mode 100755 paddle/cinn/frontend/program_pass.h
 create mode 100644 paddle/cinn/frontend/syntax.cc
 create mode 100644 paddle/cinn/frontend/syntax.h
 create mode 100644 paddle/cinn/frontend/syntax_test.cc
 create mode 100644 paddle/cinn/frontend/var_type_utils.h
 create mode 100644 paddle/cinn/gtest_main.cc
 create mode 100644 paddle/cinn/hlir/CMakeLists.txt
 create mode 100755 paddle/cinn/hlir/framework/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/framework/accuracy_checker.cc
 create mode 100644 paddle/cinn/hlir/framework/accuracy_checker.h
 create mode 100644 paddle/cinn/hlir/framework/accuracy_checker_test.cc
 create mode 100755 paddle/cinn/hlir/framework/buffer.cc
 create mode 100644 paddle/cinn/hlir/framework/buffer.h
 create mode 100755 paddle/cinn/hlir/framework/buffer_test.cc
 create mode 100644 paddle/cinn/hlir/framework/graph.cc
 create mode 100644 paddle/cinn/hlir/framework/graph.h
 create mode 100644 paddle/cinn/hlir/framework/graph_compiler.cc
 create mode 100644 paddle/cinn/hlir/framework/graph_compiler.h
 create mode 100644 paddle/cinn/hlir/framework/graph_compiler_test.cc
 create mode 100644 paddle/cinn/hlir/framework/graph_test.cc
 create mode 100644 paddle/cinn/hlir/framework/instruction.cc
 create mode 100644 paddle/cinn/hlir/framework/instruction.h
 create mode 100644 paddle/cinn/hlir/framework/instruction_test.cc
 create mode 100755 paddle/cinn/hlir/framework/memory.cc
 create mode 100755 paddle/cinn/hlir/framework/memory.h
 create mode 100644 paddle/cinn/hlir/framework/node.cc
 create mode 100644 paddle/cinn/hlir/framework/node.h
 create mode 100755 paddle/cinn/hlir/framework/op.h
 create mode 100644 paddle/cinn/hlir/framework/op_lowering.cc
 create mode 100755 paddle/cinn/hlir/framework/op_lowering.h
 create mode 100644 paddle/cinn/hlir/framework/op_lowering_test.cc
 create mode 100644 paddle/cinn/hlir/framework/op_lowering_util.cc
 create mode 100644 paddle/cinn/hlir/framework/op_lowering_util.h
 create mode 100644 paddle/cinn/hlir/framework/op_strategy.cc
 create mode 100644 paddle/cinn/hlir/framework/op_strategy.h
 create mode 100644 paddle/cinn/hlir/framework/op_test.cc
 create mode 100644 paddle/cinn/hlir/framework/parallel_compiler.cc
 create mode 100644 paddle/cinn/hlir/framework/parallel_compiler.h
 create mode 100644 paddle/cinn/hlir/framework/parallel_compiler_test.cc
 create mode 100644 paddle/cinn/hlir/framework/pass.cc
 create mode 100644 paddle/cinn/hlir/framework/pass.h
 create mode 100644 paddle/cinn/hlir/framework/print_graph_pass_test.cc
 create mode 100644 paddle/cinn/hlir/framework/schedule.h
 create mode 100755 paddle/cinn/hlir/framework/scope.cc
 create mode 100755 paddle/cinn/hlir/framework/scope.h
 create mode 100644 paddle/cinn/hlir/framework/scope_test.cc
 create mode 100644 paddle/cinn/hlir/framework/tensor.cc
 create mode 100644 paddle/cinn/hlir/framework/tensor.h
 create mode 100644 paddle/cinn/hlir/framework/tensor_test.cc
 create mode 100644 paddle/cinn/hlir/framework/variable.cc
 create mode 100644 paddle/cinn/hlir/framework/variable.h
 create mode 100644 paddle/cinn/hlir/framework/visualize_helper.cc
 create mode 100644 paddle/cinn/hlir/framework/visualize_helper.h
 create mode 100644 paddle/cinn/hlir/kernels/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/op/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/op/broadcast.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/op/contrib/argmax.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/argmax.h
 create mode 100644 paddle/cinn/hlir/op/contrib/argmax_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/argmin.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/argmin.h
 create mode 100644 paddle/cinn/hlir/op/contrib/argmin_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/assert_true.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/bitcast_convert.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/cholesky.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/gather_nd.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/gather_nd.h
 create mode 100644 paddle/cinn/hlir/op/contrib/gather_nd_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/gaussian_random.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/logical_right_shift.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/logical_right_shift.h
 create mode 100644 paddle/cinn/hlir/op/contrib/logical_right_shift_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/lookup_table.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/lookup_table.h
 create mode 100644 paddle/cinn/hlir/op/contrib/lookup_table_test.cc
 create mode 100755 paddle/cinn/hlir/op/contrib/one_hot.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/one_hot.h
 create mode 100644 paddle/cinn/hlir/op/contrib/one_hot_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/randint.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/reciprocal.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/reciprocal.h
 create mode 100644 paddle/cinn/hlir/op/contrib/reciprocal_test.cc
 create mode 100755 paddle/cinn/hlir/op/contrib/repeat.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/repeat.h
 create mode 100755 paddle/cinn/hlir/op/contrib/repeat_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/resize.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/resize.h
 create mode 100644 paddle/cinn/hlir/op/contrib/sort.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/sort.h
 create mode 100644 paddle/cinn/hlir/op/contrib/sort_test.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/triangular_solve.cc
 create mode 100644 paddle/cinn/hlir/op/contrib/uniform_random.cc
 create mode 100644 paddle/cinn/hlir/op/custom_call.cc
 create mode 100644 paddle/cinn/hlir/op/elementwise.cc
 create mode 100644 paddle/cinn/hlir/op/external_api_registry.cc
 create mode 100644 paddle/cinn/hlir/op/external_api_registry.h
 create mode 100644 paddle/cinn/hlir/op/external_api_registry_test.cc
 create mode 100644 paddle/cinn/hlir/op/nn.cc
 create mode 100755 paddle/cinn/hlir/op/op_broadcast_test.cc
 create mode 100644 paddle/cinn/hlir/op/op_nn_test.cc
 create mode 100644 paddle/cinn/hlir/op/op_util.cc
 create mode 100644 paddle/cinn/hlir/op/op_util.h
 create mode 100644 paddle/cinn/hlir/op/reduction.cc
 create mode 100644 paddle/cinn/hlir/op/reduction_test.cc
 create mode 100644 paddle/cinn/hlir/op/transform.cc
 create mode 100644 paddle/cinn/hlir/op/transform_test.cc
 create mode 100644 paddle/cinn/hlir/op/use_ops.h
 create mode 100644 paddle/cinn/hlir/pass/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/pass/alterlayout.cc
 create mode 100755 paddle/cinn/hlir/pass/alterlayout_test.cc
 create mode 100644 paddle/cinn/hlir/pass/check_fusion_accuracy_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/common_subexpression_elimination.cc
 create mode 100644 paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc
 create mode 100644 paddle/cinn/hlir/pass/const_propagate.cc
 create mode 100644 paddle/cinn/hlir/pass/const_propagate_test.cc
 create mode 100644 paddle/cinn/hlir/pass/constant_folding_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/constant_folding_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/constant_folding_pass_util.cc
 create mode 100644 paddle/cinn/hlir/pass/constant_folding_pass_util.h
 create mode 100644 paddle/cinn/hlir/pass/custom_call_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/dce_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/dce_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/dense_merge_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/dense_merge_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/dot_merger.cc
 create mode 100644 paddle/cinn/hlir/pass/dot_merger_test.cc
 create mode 100644 paddle/cinn/hlir/pass/fusion_helper_base.h
 create mode 100644 paddle/cinn/hlir/pass/fusion_merge_pass.cc
 create mode 100755 paddle/cinn/hlir/pass/fusion_merge_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/fusion_merge_pass_util.h
 create mode 100755 paddle/cinn/hlir/pass/infershape.cc
 create mode 100644 paddle/cinn/hlir/pass/infershape.h
 create mode 100644 paddle/cinn/hlir/pass/op_fusion_pass.cc
 create mode 100755 paddle/cinn/hlir/pass/op_fusion_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/op_fusion_pass_util.h
 create mode 100644 paddle/cinn/hlir/pass/opfusion.cc
 create mode 100755 paddle/cinn/hlir/pass/opfusion_test.cc
 create mode 100644 paddle/cinn/hlir/pass/reduce_split_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/reduce_split_pass_test.cc
 create mode 100644 paddle/cinn/hlir/pass/single_group_optimize_pass.cc
 create mode 100644 paddle/cinn/hlir/pass/test_dot_merger.cc
 create mode 100755 paddle/cinn/hlir/pass/test_primitive_ops.cc
 create mode 100644 paddle/cinn/hlir/pass/use_pass.h
 create mode 100755 paddle/cinn/hlir/pe/CMakeLists.txt
 create mode 100644 paddle/cinn/hlir/pe/broadcast.cc
 create mode 100644 paddle/cinn/hlir/pe/broadcast.h
 create mode 100644 paddle/cinn/hlir/pe/elementwise.cc
 create mode 100644 paddle/cinn/hlir/pe/elementwise.h
 create mode 100644 paddle/cinn/hlir/pe/ir_schedule_pe.cc
 create mode 100644 paddle/cinn/hlir/pe/ir_schedule_pe.h
 create mode 100644 paddle/cinn/hlir/pe/load_params_test.cc
 create mode 100644 paddle/cinn/hlir/pe/load_x86_params.cc
 create mode 100644 paddle/cinn/hlir/pe/load_x86_params.h
 create mode 100644 paddle/cinn/hlir/pe/nn.cc
 create mode 100755 paddle/cinn/hlir/pe/nn.h
 create mode 100644 paddle/cinn/hlir/pe/nn_util.cc
 create mode 100644 paddle/cinn/hlir/pe/nn_util.h
 create mode 100644 paddle/cinn/hlir/pe/pe_broadcast_test.cc
 create mode 100644 paddle/cinn/hlir/pe/pe_elementwise_test.cc
 create mode 100644 paddle/cinn/hlir/pe/pe_transform_test.cc
 create mode 100644 paddle/cinn/hlir/pe/reduction.cc
 create mode 100644 paddle/cinn/hlir/pe/reduction.h
 create mode 100644 paddle/cinn/hlir/pe/schedule.cc
 create mode 100644 paddle/cinn/hlir/pe/schedule.h
 create mode 100644 paddle/cinn/hlir/pe/schedule_param.proto
 create mode 100644 paddle/cinn/hlir/pe/transform.cc
 create mode 100644 paddle/cinn/hlir/pe/transform.h
 create mode 100644 paddle/cinn/hlir/pe/vision.cc
 create mode 100644 paddle/cinn/hlir/pe/vision.h
 create mode 100755 paddle/cinn/ir/CMakeLists.txt
 create mode 100755 paddle/cinn/ir/buffer.cc
 create mode 100755 paddle/cinn/ir/buffer.h
 create mode 100644 paddle/cinn/ir/buffer_test.cc
 create mode 100644 paddle/cinn/ir/collect_ir_nodes.cc
 create mode 100755 paddle/cinn/ir/collect_ir_nodes.h
 create mode 100644 paddle/cinn/ir/collect_ir_nodes_test.cc
 create mode 100644 paddle/cinn/ir/function_base.cc
 create mode 100644 paddle/cinn/ir/function_base.h
 create mode 100644 paddle/cinn/ir/function_definition.cc
 create mode 100644 paddle/cinn/ir/function_definition.h
 create mode 100644 paddle/cinn/ir/intrinsic_ops.cc
 create mode 100644 paddle/cinn/ir/intrinsic_ops.h
 create mode 100644 paddle/cinn/ir/intrinsic_ops_test.cc
 create mode 100755 paddle/cinn/ir/ir.cc
 create mode 100644 paddle/cinn/ir/ir.h
 create mode 100644 paddle/cinn/ir/ir_base.cc
 create mode 100644 paddle/cinn/ir/ir_base.h
 create mode 100644 paddle/cinn/ir/ir_compare.cc
 create mode 100644 paddle/cinn/ir/ir_compare.h
 create mode 100644 paddle/cinn/ir/ir_compare_test.cc
 create mode 100644 paddle/cinn/ir/ir_mutator.cc
 create mode 100755 paddle/cinn/ir/ir_mutator.h
 create mode 100644 paddle/cinn/ir/ir_operators.cc
 create mode 100644 paddle/cinn/ir/ir_operators.h
 create mode 100644 paddle/cinn/ir/ir_operators_test.cc
 create mode 100644 paddle/cinn/ir/ir_printer.cc
 create mode 100644 paddle/cinn/ir/ir_printer.h
 create mode 100644 paddle/cinn/ir/ir_printer_test.cc
 create mode 100644 paddle/cinn/ir/ir_schedule.cc
 create mode 100644 paddle/cinn/ir/ir_schedule.h
 create mode 100644 paddle/cinn/ir/ir_schedule_util.cc
 create mode 100644 paddle/cinn/ir/ir_schedule_util.h
 create mode 100644 paddle/cinn/ir/ir_test.cc
 create mode 100644 paddle/cinn/ir/ir_verify.cc
 create mode 100644 paddle/cinn/ir/ir_verify.h
 create mode 100644 paddle/cinn/ir/ir_verify_test.cc
 create mode 100644 paddle/cinn/ir/ir_visitor.cc
 create mode 100644 paddle/cinn/ir/ir_visitor.h
 create mode 100644 paddle/cinn/ir/layout.cc
 create mode 100644 paddle/cinn/ir/layout.h
 create mode 100644 paddle/cinn/ir/lowered_func.cc
 create mode 100755 paddle/cinn/ir/lowered_func.h
 create mode 100644 paddle/cinn/ir/module.cc
 create mode 100644 paddle/cinn/ir/module.h
 create mode 100644 paddle/cinn/ir/operation.cc
 create mode 100644 paddle/cinn/ir/operation.h
 create mode 100644 paddle/cinn/ir/registry.cc
 create mode 100644 paddle/cinn/ir/registry.h
 create mode 100644 paddle/cinn/ir/schedule_desc.cc
 create mode 100644 paddle/cinn/ir/schedule_desc.h
 create mode 100644 paddle/cinn/ir/schedule_desc.proto
 create mode 100644 paddle/cinn/ir/schedule_desc_test.cc
 create mode 100755 paddle/cinn/ir/tensor.cc
 create mode 100644 paddle/cinn/ir/tensor.h
 create mode 100755 paddle/cinn/ir/tensor_test.cc
 create mode 100644 paddle/cinn/lang/CMakeLists.txt
 create mode 100644 paddle/cinn/lang/README.md
 create mode 100644 paddle/cinn/lang/buffer.cc
 create mode 100644 paddle/cinn/lang/buffer.h
 create mode 100644 paddle/cinn/lang/builtin.cc
 create mode 100644 paddle/cinn/lang/builtin.h
 create mode 100644 paddle/cinn/lang/compute.cc
 create mode 100755 paddle/cinn/lang/compute.h
 create mode 100644 paddle/cinn/lang/compute_test.cc
 create mode 100755 paddle/cinn/lang/lower.cc
 create mode 100644 paddle/cinn/lang/lower.h
 create mode 100644 paddle/cinn/lang/lower_impl.cc
 create mode 100644 paddle/cinn/lang/lower_impl.h
 create mode 100644 paddle/cinn/lang/lower_impl_test.cc
 create mode 100755 paddle/cinn/lang/lower_test.cc
 create mode 100644 paddle/cinn/lang/packed_func.cc
 create mode 100644 paddle/cinn/lang/packed_func.h
 create mode 100644 paddle/cinn/lang/packed_func_test.cc
 create mode 100644 paddle/cinn/lang/placeholder.cc
 create mode 100644 paddle/cinn/lang/placeholder.h
 create mode 100644 paddle/cinn/lang/placeholder_test.cc
 create mode 100755 paddle/cinn/optim/CMakeLists.txt
 create mode 100644 paddle/cinn/optim/buffer_assign.cc
 create mode 100644 paddle/cinn/optim/buffer_assign.h
 create mode 100755 paddle/cinn/optim/cache_read_write_replace_test.cc
 create mode 100644 paddle/cinn/optim/call_arg_list_to_pod_value.cc
 create mode 100644 paddle/cinn/optim/call_arg_list_to_pod_value.h
 create mode 100644 paddle/cinn/optim/cast_bool_to_int8.cc
 create mode 100644 paddle/cinn/optim/cast_bool_to_int8.h
 create mode 100644 paddle/cinn/optim/cast_simplify.cc
 create mode 100644 paddle/cinn/optim/cast_simplify.h
 create mode 100644 paddle/cinn/optim/cast_simplify_test.cc
 create mode 100644 paddle/cinn/optim/collect_undefined_vars.cc
 create mode 100644 paddle/cinn/optim/collect_undefined_vars.h
 create mode 100644 paddle/cinn/optim/compute_inline_expand.cc
 create mode 100644 paddle/cinn/optim/compute_inline_expand.h
 create mode 100644 paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
 create mode 100644 paddle/cinn/optim/eliminate_broadcast_in_forloop.h
 create mode 100644 paddle/cinn/optim/extern_call_process.cc
 create mode 100644 paddle/cinn/optim/extern_call_process.h
 create mode 100644 paddle/cinn/optim/fold_cinn_call_arguments.cc
 create mode 100644 paddle/cinn/optim/fold_cinn_call_arguments.h
 create mode 100644 paddle/cinn/optim/if_simplify.cc
 create mode 100644 paddle/cinn/optim/if_simplify.h
 create mode 100644 paddle/cinn/optim/if_simplify_test.cc
 create mode 100644 paddle/cinn/optim/insert_debug_log_callee.cc
 create mode 100644 paddle/cinn/optim/insert_debug_log_callee.h
 create mode 100644 paddle/cinn/optim/ir_copy.cc
 create mode 100644 paddle/cinn/optim/ir_copy.h
 create mode 100644 paddle/cinn/optim/ir_copy_test.cc
 create mode 100755 paddle/cinn/optim/ir_replace.cc
 create mode 100644 paddle/cinn/optim/ir_replace.h
 create mode 100644 paddle/cinn/optim/ir_simplify.cc
 create mode 100644 paddle/cinn/optim/ir_simplify.h
 create mode 100755 paddle/cinn/optim/ir_simplify_test.cc
 create mode 100644 paddle/cinn/optim/lower_function_call_bind_vars.cc
 create mode 100644 paddle/cinn/optim/lower_function_call_bind_vars.h
 create mode 100644 paddle/cinn/optim/lower_intrin.cc
 create mode 100644 paddle/cinn/optim/lower_intrin.h
 create mode 100644 paddle/cinn/optim/map_extern_call.cc
 create mode 100644 paddle/cinn/optim/map_extern_call.h
 create mode 100644 paddle/cinn/optim/optimize.cc
 create mode 100644 paddle/cinn/optim/optimize.h
 create mode 100755 paddle/cinn/optim/optimize_test.cc
 create mode 100644 paddle/cinn/optim/remove_nested_block.cc
 create mode 100644 paddle/cinn/optim/remove_nested_block.h
 create mode 100644 paddle/cinn/optim/remove_nested_block_test.cc
 create mode 100644 paddle/cinn/optim/remove_schedule_block.cc
 create mode 100644 paddle/cinn/optim/remove_schedule_block.h
 create mode 100755 paddle/cinn/optim/remove_schedule_block_test.cc
 create mode 100644 paddle/cinn/optim/replace_call_with_expr.cc
 create mode 100644 paddle/cinn/optim/replace_call_with_expr.h
 create mode 100644 paddle/cinn/optim/replace_call_with_expr_test.cc
 create mode 100644 paddle/cinn/optim/replace_const_param_to_integer.cc
 create mode 100644 paddle/cinn/optim/replace_const_param_to_integer.h
 create mode 100644 paddle/cinn/optim/replace_var_with_expr.cc
 create mode 100644 paddle/cinn/optim/replace_var_with_expr.h
 create mode 100644 paddle/cinn/optim/tensor_write_tell.cc
 create mode 100644 paddle/cinn/optim/tensor_write_tell.h
 create mode 100644 paddle/cinn/optim/transform_gpu_forloop.cc
 create mode 100644 paddle/cinn/optim/transform_gpu_forloop.h
 create mode 100644 paddle/cinn/optim/transform_polyfor_to_for.cc
 create mode 100644 paddle/cinn/optim/transform_polyfor_to_for.h
 create mode 100644 paddle/cinn/optim/transform_polyfor_to_for_test.cc
 create mode 100755 paddle/cinn/optim/unroll_loops.cc
 create mode 100644 paddle/cinn/optim/unroll_loops.h
 create mode 100644 paddle/cinn/optim/unroll_loops_test.cc
 create mode 100644 paddle/cinn/optim/var_mod_simplify.cc
 create mode 100644 paddle/cinn/optim/var_mod_simplify.h
 create mode 100644 paddle/cinn/optim/vectorize_loops.cc
 create mode 100644 paddle/cinn/optim/vectorize_loops.h
 create mode 100644 paddle/cinn/optim/vectorize_loops_test.cc
 create mode 100644 paddle/cinn/poly/CMakeLists.txt
 create mode 100644 paddle/cinn/poly/ast_gen.cc
 create mode 100644 paddle/cinn/poly/ast_gen.h
 create mode 100644 paddle/cinn/poly/ast_gen_test.cc
 create mode 100755 paddle/cinn/poly/compute_at_transform.cc
 create mode 100644 paddle/cinn/poly/compute_at_transform.h
 create mode 100644 paddle/cinn/poly/compute_at_transform_test.cc
 create mode 100644 paddle/cinn/poly/dim.cc
 create mode 100644 paddle/cinn/poly/dim.h
 create mode 100644 paddle/cinn/poly/domain.cc
 create mode 100644 paddle/cinn/poly/domain.h
 create mode 100644 paddle/cinn/poly/domain_add_unit_loop_mutator.cc
 create mode 100644 paddle/cinn/poly/domain_add_unit_loop_mutator.h
 create mode 100755 paddle/cinn/poly/graph.cc
 create mode 100644 paddle/cinn/poly/graph.h
 create mode 100644 paddle/cinn/poly/graph_test.cc
 create mode 100644 paddle/cinn/poly/isl_utils.cc
 create mode 100644 paddle/cinn/poly/isl_utils.h
 create mode 100644 paddle/cinn/poly/isl_utils_test.cc
 create mode 100644 paddle/cinn/poly/map.cc
 create mode 100644 paddle/cinn/poly/map.h
 create mode 100644 paddle/cinn/poly/naive_scheduler.cc
 create mode 100644 paddle/cinn/poly/naive_scheduler.h
 create mode 100755 paddle/cinn/poly/poly_scheduler.cc
 create mode 100644 paddle/cinn/poly/poly_scheduler.h
 create mode 100644 paddle/cinn/poly/poly_scheduler_test.cc
 create mode 100644 paddle/cinn/poly/schedule.cc
 create mode 100755 paddle/cinn/poly/schedule.h
 create mode 100755 paddle/cinn/poly/schedule_test.cc
 create mode 100644 paddle/cinn/poly/stage.cc
 create mode 100755 paddle/cinn/poly/stage.h
 create mode 100755 paddle/cinn/poly/stage_test.cc
 create mode 100755 paddle/cinn/pybind/CMakeLists.txt
 create mode 100644 paddle/cinn/pybind/backends.cc
 create mode 100644 paddle/cinn/pybind/bind.cc
 create mode 100644 paddle/cinn/pybind/bind.h
 create mode 100644 paddle/cinn/pybind/bind_utils.h
 create mode 100644 paddle/cinn/pybind/common.cc
 create mode 100755 paddle/cinn/pybind/framework.cc
 create mode 100644 paddle/cinn/pybind/frontend.cc
 create mode 100755 paddle/cinn/pybind/ir.cc
 create mode 100644 paddle/cinn/pybind/lang.cc
 create mode 100755 paddle/cinn/pybind/optim.cc
 create mode 100644 paddle/cinn/pybind/pe.cc
 create mode 100644 paddle/cinn/pybind/poly.cc
 create mode 100644 paddle/cinn/pybind/runtime.cc
 create mode 100644 paddle/cinn/pybind/utils.cc
 create mode 100644 paddle/cinn/runtime/CMakeLists.txt
 create mode 100755 paddle/cinn/runtime/buffer.cc
 create mode 100755 paddle/cinn/runtime/buffer.h
 create mode 100644 paddle/cinn/runtime/cinn_runtime.cc
 create mode 100755 paddle/cinn/runtime/cinn_runtime.h
 create mode 100644 paddle/cinn/runtime/cinn_runtime_test.cc
 create mode 100644 paddle/cinn/runtime/cinn_x86_device_impl.cc
 create mode 100644 paddle/cinn/runtime/cpu/CMakeLists.txt
 create mode 100644 paddle/cinn/runtime/cpu/cblas.cc
 create mode 100644 paddle/cinn/runtime/cpu/cblas.h
 create mode 100644 paddle/cinn/runtime/cpu/host_intrinsics.cc
 create mode 100644 paddle/cinn/runtime/cpu/host_intrinsics.h
 create mode 100644 paddle/cinn/runtime/cpu/host_intrinsics_test.cc
 create mode 100644 paddle/cinn/runtime/cpu/mkl_math.cc
 create mode 100644 paddle/cinn/runtime/cpu/mkl_math.h
 create mode 100644 paddle/cinn/runtime/cpu/mkl_math_test.cc
 create mode 100644 paddle/cinn/runtime/cpu/mkldnn_math.cc
 create mode 100644 paddle/cinn/runtime/cpu/mkldnn_math.h
 create mode 100644 paddle/cinn/runtime/cpu/mkldnn_math_test.cc
 create mode 100644 paddle/cinn/runtime/cpu/thread_backend.cc
 create mode 100644 paddle/cinn/runtime/cpu/thread_backend.h
 create mode 100644 paddle/cinn/runtime/cpu/use_extern_funcs.h
 create mode 100755 paddle/cinn/runtime/cuda/CMakeLists.txt
 create mode 100644 paddle/cinn/runtime/cuda/bfloat16.h
 create mode 100644 paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
 create mode 100644 paddle/cinn/runtime/cuda/cublas_util.h
 create mode 100644 paddle/cinn/runtime/cuda/cuda_instrinsics_bfloat16.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_instrinsics_float16.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_intrinsics.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_module.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_module.h
 create mode 100644 paddle/cinn/runtime/cuda/cuda_module_test.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_util.cc
 create mode 100644 paddle/cinn/runtime/cuda/cuda_util.h
 create mode 100644 paddle/cinn/runtime/cuda/float16.h
 create mode 100644 paddle/cinn/runtime/cuda/test_util.h
 create mode 100644 paddle/cinn/runtime/cuda/use_extern_funcs.h
 create mode 100644 paddle/cinn/runtime/custom_function.cc
 create mode 100644 paddle/cinn/runtime/custom_function.h
 create mode 100644 paddle/cinn/runtime/custom_function_test.cc
 create mode 100644 paddle/cinn/runtime/flags.cc
 create mode 100644 paddle/cinn/runtime/flags.h
 create mode 100644 paddle/cinn/runtime/intrinsic.cc
 create mode 100644 paddle/cinn/runtime/intrinsic.h
 create mode 100644 paddle/cinn/runtime/intrinsic_types.cc
 create mode 100644 paddle/cinn/runtime/intrinsic_types.h
 create mode 100644 paddle/cinn/runtime/tiny_runtime.cc
 create mode 100644 paddle/cinn/runtime/use_extern_funcs.h
 create mode 100755 paddle/cinn/utils/CMakeLists.txt
 create mode 100644 paddle/cinn/utils/data_util.cc
 create mode 100644 paddle/cinn/utils/data_util.h
 create mode 100644 paddle/cinn/utils/dot_lang.cc
 create mode 100644 paddle/cinn/utils/dot_lang.h
 create mode 100644 paddle/cinn/utils/error.cc
 create mode 100644 paddle/cinn/utils/error.h
 create mode 100644 paddle/cinn/utils/event.cc
 create mode 100644 paddle/cinn/utils/event.h
 create mode 100644 paddle/cinn/utils/functional.cc
 create mode 100644 paddle/cinn/utils/functional.h
 create mode 100644 paddle/cinn/utils/functional_test.cc
 create mode 100644 paddle/cinn/utils/multi_threading.cc
 create mode 100644 paddle/cinn/utils/multi_threading.h
 create mode 100644 paddle/cinn/utils/multi_threading_test.cc
 create mode 100644 paddle/cinn/utils/profiler.cc
 create mode 100644 paddle/cinn/utils/profiler.h
 create mode 100644 paddle/cinn/utils/profiler_test.cc
 create mode 100644 paddle/cinn/utils/random_engine.cc
 create mode 100644 paddle/cinn/utils/random_engine.h
 create mode 100644 paddle/cinn/utils/registry.h
 create mode 100644 paddle/cinn/utils/sized_multi_set.cc
 create mode 100644 paddle/cinn/utils/sized_multi_set.h
 create mode 100644 paddle/cinn/utils/sized_multi_set_test.cc
 create mode 100644 paddle/cinn/utils/small_vector.cc
 create mode 100644 paddle/cinn/utils/small_vector.h
 create mode 100644 paddle/cinn/utils/string.cc
 create mode 100644 paddle/cinn/utils/string.h
 create mode 100644 paddle/cinn/utils/string_test.cc
 create mode 100644 paddle/cinn/utils/timer.cc
 create mode 100644 paddle/cinn/utils/timer.h
 create mode 100644 paddle/cinn/utils/type_defs.h
 create mode 100644 python/cinn/__init__.py
 create mode 100644 python/cinn/auto_schedule/__init__.py
 create mode 100644 python/cinn/auto_schedule/cost_model/__init__.py
 create mode 100644 python/cinn/auto_schedule/cost_model/cost_model.py
 create mode 100644 python/cinn/auto_schedule/cost_model/xgb_cost_model.py
 create mode 100644 python/cinn/backends.py
 create mode 100644 python/cinn/common.py
 create mode 100644 python/cinn/framework.py
 create mode 100644 python/cinn/frontend.py
 create mode 100644 python/cinn/ir/__init__.py
 create mode 100644 python/cinn/lang.py
 create mode 100644 python/cinn/libs/__init__.py
 create mode 100644 python/cinn/optim.py
 create mode 100644 python/cinn/pe.py
 create mode 100644 python/cinn/poly.py
 create mode 100644 python/cinn/runtime.py
 create mode 100644 python/cinn/utils.py
 create mode 100644 python/cinn/version/__init__.py
 create mode 100644 test/cinn/auto_schedule/cost_model/test_cost_model.py
 create mode 100644 test/cinn/conv2d_utils.py
 create mode 100644 test/cinn/fake_model/naive_mul.py
 create mode 100644 test/cinn/fake_model/naive_multi_fc.py
 create mode 100644 test/cinn/fake_model/resnet_model.py
 create mode 100644 test/cinn/fusion/fusion_test.py
 create mode 100644 test/cinn/fusion/test_cast_broadcast_reduce_max.py
 create mode 100644 test/cinn/fusion/test_reduce_cast.py
 create mode 100644 test/cinn/fusion/test_select_reduce.py
 create mode 100644 test/cinn/op_mappers/op_mapper_test.py
 create mode 100644 test/cinn/op_mappers/test_argmax_op.py
 create mode 100644 test/cinn/op_mappers/test_argmin_op.py
 create mode 100644 test/cinn/op_mappers/test_argsort_op.py
 create mode 100644 test/cinn/op_mappers/test_assign_value_op.py
 create mode 100644 test/cinn/op_mappers/test_atan2_op.py
 create mode 100644 test/cinn/op_mappers/test_batch_norm_op.py
 create mode 100644 test/cinn/op_mappers/test_bitwise_op.py
 create mode 100644 test/cinn/op_mappers/test_cholesky_op.py
 create mode 100644 test/cinn/op_mappers/test_clip_op.py
 create mode 100644 test/cinn/op_mappers/test_compare_op.py
 create mode 100644 test/cinn/op_mappers/test_conv2d_op.py
 create mode 100644 test/cinn/op_mappers/test_cumsum_op.py
 create mode 100644 test/cinn/op_mappers/test_elementwise_op.py
 create mode 100644 test/cinn/op_mappers/test_expand_op.py
 create mode 100644 test/cinn/op_mappers/test_expand_v2_op.py
 create mode 100644 test/cinn/op_mappers/test_fill_constant_op.py
 create mode 100644 test/cinn/op_mappers/test_flip_op.py
 create mode 100644 test/cinn/op_mappers/test_gather_nd_op.py
 create mode 100644 test/cinn/op_mappers/test_gather_op.py
 create mode 100644 test/cinn/op_mappers/test_gaussian_random_op.py
 create mode 100644 test/cinn/op_mappers/test_layer_norm_op.py
 create mode 100644 test/cinn/op_mappers/test_log1p_op.py
 create mode 100644 test/cinn/op_mappers/test_logical_op.py
 create mode 100644 test/cinn/op_mappers/test_lookup_table_op.py
 create mode 100644 test/cinn/op_mappers/test_mul_op.py
 create mode 100644 test/cinn/op_mappers/test_norm_op.py
 create mode 100644 test/cinn/op_mappers/test_one_hot_op.py
 create mode 100644 test/cinn/op_mappers/test_pool2d_op.py
 create mode 100644 test/cinn/op_mappers/test_pow_op.py
 create mode 100644 test/cinn/op_mappers/test_randint_op.py
 create mode 100644 test/cinn/op_mappers/test_reduce_op.py
 create mode 100644 test/cinn/op_mappers/test_reverse_op.py
 create mode 100644 test/cinn/op_mappers/test_roll_op.py
 create mode 100644 test/cinn/op_mappers/test_scale_op.py
 create mode 100644 test/cinn/op_mappers/test_scatter_op.py
 create mode 100644 test/cinn/op_mappers/test_sign_op.py
 create mode 100644 test/cinn/op_mappers/test_split_op.py
 create mode 100644 test/cinn/op_mappers/test_squeeze_op.py
 create mode 100644 test/cinn/op_mappers/test_stack_op.py
 create mode 100644 test/cinn/op_mappers/test_strided_slice_op.py
 create mode 100644 test/cinn/op_mappers/test_take_along_axis_op.py
 create mode 100644 test/cinn/op_mappers/test_tile_op.py
 create mode 100644 test/cinn/op_mappers/test_transpose2_op.py
 create mode 100644 test/cinn/op_mappers/test_triangular_solve_op.py
 create mode 100644 test/cinn/op_mappers/test_unary_op.py
 create mode 100644 test/cinn/op_mappers/test_uniform_random_op.py
 create mode 100644 test/cinn/op_mappers/test_where_op.py
 create mode 100755 test/cinn/ops/op_test.py
 create mode 100644 test/cinn/ops/op_test_helper.py
 create mode 100644 test/cinn/ops/test_abs_op.py
 create mode 100644 test/cinn/ops/test_acos_op.py
 create mode 100644 test/cinn/ops/test_add_op.py
 create mode 100644 test/cinn/ops/test_arange_op.py
 create mode 100644 test/cinn/ops/test_argsort_op.py
 create mode 100644 test/cinn/ops/test_asin_op.py
 create mode 100644 test/cinn/ops/test_asinh_op.py
 create mode 100644 test/cinn/ops/test_atan2_op.py
 create mode 100644 test/cinn/ops/test_atan_op.py
 create mode 100644 test/cinn/ops/test_atanh_op.py
 create mode 100644 test/cinn/ops/test_batch_norm_op.py
 create mode 100644 test/cinn/ops/test_binary_elementwise_op.py
 create mode 100644 test/cinn/ops/test_bitcast_convert_op.py
 create mode 100644 test/cinn/ops/test_bitwise_op.py
 create mode 100644 test/cinn/ops/test_broadcast_to_op.py
 create mode 100644 test/cinn/ops/test_broadcast_to_op_new.py
 create mode 100644 test/cinn/ops/test_cast_op.py
 create mode 100644 test/cinn/ops/test_cbrt_op.py
 create mode 100644 test/cinn/ops/test_ceil_op.py
 create mode 100644 test/cinn/ops/test_cholesky_op.py
 create mode 100644 test/cinn/ops/test_clz_op.py
 create mode 100644 test/cinn/ops/test_comparison_op.py
 create mode 100755 test/cinn/ops/test_concat_op.py
 create mode 100644 test/cinn/ops/test_constant_op.py
 create mode 100755 test/cinn/ops/test_conv2d_op.py
 create mode 100644 test/cinn/ops/test_cos_op.py
 create mode 100644 test/cinn/ops/test_cosh_op.py
 create mode 100644 test/cinn/ops/test_depthwise_conv2d_op.py
 create mode 100644 test/cinn/ops/test_divide_op.py
 create mode 100644 test/cinn/ops/test_dropout_infer_op.py
 create mode 100644 test/cinn/ops/test_erf_op.py
 create mode 100644 test/cinn/ops/test_exp_op.py
 create mode 100644 test/cinn/ops/test_expand_dims.py
 create mode 100644 test/cinn/ops/test_fill_constant_op.py
 create mode 100644 test/cinn/ops/test_floor_divide_op.py
 create mode 100644 test/cinn/ops/test_floor_op.py
 create mode 100644 test/cinn/ops/test_gather_nd_op.py
 create mode 100644 test/cinn/ops/test_gather_op.py
 create mode 100644 test/cinn/ops/test_gaussian_random_op.py
 create mode 100644 test/cinn/ops/test_gelu_op.py
 create mode 100644 test/cinn/ops/test_identity_op.py
 create mode 100644 test/cinn/ops/test_is_finite_op.py
 create mode 100644 test/cinn/ops/test_is_inf_op.py
 create mode 100644 test/cinn/ops/test_is_nan_op.py
 create mode 100644 test/cinn/ops/test_isclose_op.py
 create mode 100644 test/cinn/ops/test_left_shift_op.py
 create mode 100644 test/cinn/ops/test_log_op.py
 create mode 100644 test/cinn/ops/test_logical_right_shift_op.py
 create mode 100644 test/cinn/ops/test_lookup_table_op.py
 create mode 100755 test/cinn/ops/test_matmul_op.py
 create mode 100644 test/cinn/ops/test_max_op.py
 create mode 100644 test/cinn/ops/test_mod_op.py
 create mode 100755 test/cinn/ops/test_mul_op.py
 create mode 100644 test/cinn/ops/test_multiply_op.py
 create mode 100644 test/cinn/ops/test_negative_op.py
 create mode 100755 test/cinn/ops/test_one_hot_op.py
 create mode 100644 test/cinn/ops/test_pool2d_op.py
 create mode 100644 test/cinn/ops/test_popc_op.py
 create mode 100644 test/cinn/ops/test_pow_op.py
 create mode 100644 test/cinn/ops/test_randint_op.py
 create mode 100644 test/cinn/ops/test_reciprocal_op.py
 create mode 100644 test/cinn/ops/test_reduce_op.py
 create mode 100644 test/cinn/ops/test_reduce_op_new.py
 create mode 100644 test/cinn/ops/test_reduce_op_other.py
 create mode 100644 test/cinn/ops/test_relu6_op.py
 create mode 100755 test/cinn/ops/test_relu_op.py
 create mode 100644 test/cinn/ops/test_remainder_op.py
 create mode 100644 test/cinn/ops/test_repeat_op.py
 create mode 100644 test/cinn/ops/test_reshape_op.py
 create mode 100644 test/cinn/ops/test_resize_op.py
 create mode 100755 test/cinn/ops/test_reverse_op.py
 create mode 100644 test/cinn/ops/test_right_shift_op.py
 create mode 100644 test/cinn/ops/test_round_op.py
 create mode 100644 test/cinn/ops/test_rsqrt_op.py
 create mode 100644 test/cinn/ops/test_scale_op.py
 create mode 100644 test/cinn/ops/test_scatter_add.py
 create mode 100644 test/cinn/ops/test_scatter_assign_op.py
 create mode 100644 test/cinn/ops/test_select_op.py
 create mode 100644 test/cinn/ops/test_sigmoid_op.py
 create mode 100644 test/cinn/ops/test_sign_op.py
 create mode 100644 test/cinn/ops/test_sin_op.py
 create mode 100644 test/cinn/ops/test_sinh_op.py
 create mode 100644 test/cinn/ops/test_slice_assign_op.py
 create mode 100644 test/cinn/ops/test_slice_op.py
 create mode 100644 test/cinn/ops/test_softmax_op.py
 create mode 100644 test/cinn/ops/test_sort_op.py
 create mode 100755 test/cinn/ops/test_split_op.py
 create mode 100644 test/cinn/ops/test_sqrt_op.py
 create mode 100644 test/cinn/ops/test_squeeze_op.py
 create mode 100644 test/cinn/ops/test_subtract_op.py
 create mode 100644 test/cinn/ops/test_sum_op.py
 create mode 100644 test/cinn/ops/test_tan_op.py
 create mode 100644 test/cinn/ops/test_tanh_op.py
 create mode 100644 test/cinn/ops/test_top_k_op.py
 create mode 100644 test/cinn/ops/test_transpose_op.py
 create mode 100644 test/cinn/ops/test_triangular_solve_op.py
 create mode 100644 test/cinn/ops/test_trunc_op.py
 create mode 100644 test/cinn/ops/test_unary_elementwise_op.py
 create mode 100644 test/cinn/ops/test_uniform_random_op.py
 create mode 100644 test/cinn/ops/test_zero_dim_tensor.py
 create mode 100644 test/cinn/passes/pass_test.py
 create mode 100644 test/cinn/passes/test_auto_cast_pass.py
 create mode 100644 test/cinn/passes/test_expand_zero_dim_pass.py
 create mode 100644 test/cinn/passes/test_transpose_floding_input_pass.py
 create mode 100644 test/cinn/passes/test_transpose_floding_output_pass.py
 create mode 100644 test/cinn/pool_utils.py
 create mode 100644 test/cinn/test_common.py
 create mode 100755 test/cinn/test_computation.py
 create mode 100755 test/cinn/test_efficientnet.py
 create mode 100755 test/cinn/test_facedet.py
 create mode 100755 test/cinn/test_frontend.py
 create mode 100644 test/cinn/test_hlir_framework.py
 create mode 100644 test/cinn/test_ir.py
 create mode 100755 test/cinn/test_matmul.py
 create mode 100644 test/cinn/test_mobilenetv1.py
 create mode 100755 test/cinn/test_mobilenetv2.py
 create mode 100755 test/cinn/test_netbuilder.py
 create mode 100755 test/cinn/test_op_benchmark.py
 create mode 100644 test/cinn/test_op_broadcast.py
 create mode 100644 test/cinn/test_op_nn.py
 create mode 100644 test/cinn/test_op_transform.py
 create mode 100755 test/cinn/test_packed_func.py
 create mode 100644 test/cinn/test_paddle_model_convertor.py
 create mode 100644 test/cinn/test_pe_elementwise.py
 create mode 100644 test/cinn/test_pe_reduction.py
 create mode 100644 test/cinn/test_pe_transform.py
 create mode 100755 test/cinn/test_resnet.py
 create mode 100755 test/cinn/test_resnet18.py
 create mode 100755 test/cinn/test_resnet50.py
 create mode 100644 test/cinn/test_squeezenet.py
 create mode 100755 test/cinn/test_utils.py
 create mode 100644 test/cpp/cinn/CMakeLists.txt
 create mode 100755 test/cpp/cinn/benchmark/CMakeLists.txt
 create mode 100644 test/cpp/cinn/benchmark/test_all_ops_default.cc
 create mode 100644 test/cpp/cinn/benchmark/test_elementwise.cc
 create mode 100644 test/cpp/cinn/benchmark/test_elementwise.h
 create mode 100644 test/cpp/cinn/benchmark/test_matmul.cc
 create mode 100644 test/cpp/cinn/benchmark/test_matmul.h
 create mode 100755 test/cpp/cinn/benchmark/test_utils.cc
 create mode 100755
test/cpp/cinn/benchmark/test_utils.h create mode 100644 test/cpp/cinn/concrete_program_builder.h create mode 100644 test/cpp/cinn/program_builder.cc create mode 100644 test/cpp/cinn/program_builder.h create mode 100644 test/cpp/cinn/test01_elementwise_add_case.cc create mode 100644 test/cpp/cinn/test01_elementwise_add_main.cc create mode 100644 test/cpp/cinn/test02_helper.h create mode 100644 test/cpp/cinn/test02_matmul_case.cc create mode 100644 test/cpp/cinn/test02_matmul_main.cc create mode 100644 test/cpp/cinn/test03_convolution_case.cc create mode 100755 test/cpp/cinn/test03_convolution_main.cc diff --git a/cmake/cinn/external/jitify.cmake b/cmake/cinn/external/jitify.cmake index b04d64b12b8fb..8ee57c13ece4c 100644 --- a/cmake/cinn/external/jitify.cmake +++ b/cmake/cinn/external/jitify.cmake @@ -12,7 +12,6 @@ ExternalProject_Add( ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/NVIDIA/jitify.git" GIT_TAG 57de649139c866eb83acacfe50c92ad7c6278776 - GIT_TAG master PREFIX ${CINN_THIRD_PARTY_PATH}/jitify SOURCE_DIR ${JITIFY_SOURCE_PATH} CONFIGURE_COMMAND "" diff --git a/paddle/cinn/CMakeLists.txt b/paddle/cinn/CMakeLists.txt new file mode 100644 index 0000000000000..16c70714d7f36 --- /dev/null +++ b/paddle/cinn/CMakeLists.txt @@ -0,0 +1,21 @@ +if (WITH_TESTING) + cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest) +endif() + +add_subdirectory(auto_schedule) +add_subdirectory(common) +add_subdirectory(utils) +add_subdirectory(poly) +add_subdirectory(runtime) +add_subdirectory(ir) +add_subdirectory(backends) +add_subdirectory(lang) +add_subdirectory(optim) +add_subdirectory(hlir) +add_subdirectory(pybind) +add_subdirectory(frontend) + +# Download a model +download_and_uncompress("${DOWNLOAD_MODEL_DIR}" "${PADDLE_RESOURCE_URL}" "lite_naive_model.tar.gz") + +core_gather_headers() diff --git a/paddle/cinn/auto_schedule/CMakeLists.txt b/paddle/cinn/auto_schedule/CMakeLists.txt new file mode 100644 index 0000000000000..7a2d725d33ee8 --- /dev/null +++ b/paddle/cinn/auto_schedule/CMakeLists.txt @@ -0,0 +1,22 @@ +add_subdirectory(analysis) +add_subdirectory(cost_model) +add_subdirectory(database) +add_subdirectory(measure) +add_subdirectory(post_schedule_rule) +add_subdirectory(search_space) +add_subdirectory(search_strategy) +add_subdirectory(task) +add_subdirectory(task_scheduler) +add_subdirectory(tests) + +proto_library(auto_schedule_proto SRCS auto_schedule.proto DEPS schedule_desc_proto) + +core_gather_headers() + +gather_srcs(cinnapi_src SRCS auto_tuner.cc) + +#cc_test(test_auto_tuner SRCS auto_tuner_test.cc DEPS cinncore) + +foreach(header ${auto_schedule_proto_HDRS}) + set(core_proto_includes "${core_proto_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/cinn/auto_schedule/analysis/CMakeLists.txt b/paddle/cinn/auto_schedule/analysis/CMakeLists.txt new file mode 100644 index 0000000000000..46eda4a587bb8 --- /dev/null +++ b/paddle/cinn/auto_schedule/analysis/CMakeLists.txt @@ -0,0 +1,5 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS analyze_ir.cc) + +cc_test(test_analyze_ir SRCS analyze_ir_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir.cc new file mode 100644 index 0000000000000..21ff620118d59 --- /dev/null +++ b/paddle/cinn/auto_schedule/analysis/analyze_ir.cc @@ -0,0 +1,176 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/analysis/analyze_ir.h" + +#include + +#include +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/lower.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/optimize.h" +#include "cinn/optim/transform_gpu_forloop.h" + +namespace cinn { +namespace auto_schedule { + +std::vector IndicesToVars(const std::vector& indices) { + std::vector result; + for (const ir::Expr& e : indices) { + // Whether we have to convert other types, like const numbers to Var? + if (e.As() != nullptr) { + ir::Expr copy_e = optim::IRCopy(e); + ir::_Var_* var_ref = copy_e.As(); + result.emplace_back(ir::Var(var_ref)); + } + } + return result; +} + +void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) { + if (!sche_block->read_buffers.empty() || !sche_block->write_buffers.empty()) { + return; + } + + ir::CollectIRNodesWithoutTensor(sche_block->body, [&](const Expr* x) { + const ir::Load* load_expr = x->As(); + if (load_expr != nullptr) { + const ir::Tensor t = load_expr->tensor.as_tensor_ref(); + sche_block->read_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices))); + return false; + } + const ir::Store* store_expr = x->As(); + if (store_expr != nullptr) { + const ir::Tensor t = store_expr->tensor.as_tensor_ref(); + sche_block->write_buffers.emplace_back(ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices))); + return false; + } + return false; + }); +} + +bool ContainsNodeType(ir::Expr expr, const std::unordered_set& node_types) { + std::set collection = ir::CollectIRNodesWithoutTensor( + expr, [&](const Expr* x) { return node_types.find(x->node_type()) != node_types.end(); }); + return !collection.empty(); +} + +std::unordered_set GetOutputNamesFromLoweredFunc(const std::vector& lowered_funcs) { + std::unordered_set result; + for (const ir::LoweredFunc& func : lowered_funcs) { + for (const ir::Argument& arg : func->args) { + if (arg.is_output()) { + result.insert(arg.name()); + } + } + } + return result; +} + +bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize) { + const ir::ScheduleBlock* sche_block = sche_block_realize.schedule_block.As(); + if (sche_block->write_buffers.size() != 1 || sche_block->read_buffers.empty()) { + return false; + } + const ir::Expr& write_buffer = sche_block->write_buffers[0].As()->buffer; + + // Enumerate each read region, get the number of schedule block iter vars + // which are not used to index the read region + int total_unused_iter_vars = 0; + + for (const ir::Expr& read_buffer_expr : sche_block->read_buffers) { + const ir::_BufferRange_* read_buffer = read_buffer_expr.As(); + // Skip the reduction buffer + if (read_buffer->buffer == write_buffer) { + continue; + } + 
// Collect the vars in schedule block that are used to index the read region + std::unordered_set vars_index_read; + for (const Var& range : read_buffer->ranges) { + vars_index_read.insert(range->name); + } + // Check the block iter vars are not used to index the read region + int n_unused_block_vars = 0; + for (const ir::Var& block_iter_var : sche_block->iter_vars) { + if (!block_iter_var->is_reduce_axis) { + bool iter_var_in_read = false; + for (const std::string& var : vars_index_read) { + if (var == block_iter_var->name) { + iter_var_in_read = true; + break; + } + } + if (!iter_var_in_read) { + ++n_unused_block_vars; + } + } + } + total_unused_iter_vars += n_unused_block_vars; + } + + return total_unused_iter_vars >= 1; +} + +ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body) { + ir::ModuleExpr mod_expr(std::vector({body})); + ir::IRSchedule ir_sch(mod_expr); + + // temp_bufs may be deleted during auto tuning (such as auto inline), + // we have to check from old temp bufs and set them as local buffer. + for (const ir::Buffer& buf : old_func->temp_bufs) { + const std::string& buf_name = buf->name; + std::vector all_block_realizes = ir_sch.GetAllBlocks(); + for (ir::Expr& e : all_block_realizes) { + const ir::ScheduleBlockRealize* sche_block_realize = e.As(); + const std::string& sche_name = sche_block_realize->schedule_block.As()->name; + if (buf_name == "_" + sche_name) { + VLOG(6) << "Set local buffer for temp buffer " << buf_name; + ir_sch.SetBuffer(e, "local", true); + break; + } + } + } + + ir::Expr updated_body = ir_sch.GetModule().GetExprs()[0]; +#ifdef CINN_WITH_CUDA + optim::OptimizeExprGPU(&updated_body); +#endif + + // Get new temp bufs by analyzing. + std::vector new_temp_bufs = lang::GetTempBuffers(old_func->args, updated_body); + ir::LoweredFunc new_func = ir::_LoweredFunc_::Make(old_func->name, old_func->args, updated_body, new_temp_bufs); +#ifdef CINN_WITH_CUDA + if (target == common::DefaultNVGPUTarget()) { + new_func->PrepareCudaAxisInfoFromBody(); + } +#endif + new_func = optim::Optimize(Expr(new_func), target, false).as_lowered_func_ref(); + new_func->PrepareBufferCastExprs(/*with_expr_gen_tensor = */ false); + + return new_func; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir.h b/paddle/cinn/auto_schedule/analysis/analyze_ir.h new file mode 100644 index 0000000000000..f2d214db89e43 --- /dev/null +++ b/paddle/cinn/auto_schedule/analysis/analyze_ir.h @@ -0,0 +1,48 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
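The unused-iter-var rule implemented by NeedsMultiLevelTiling above can be reproduced in isolation. Below is a minimal standalone sketch with hypothetical toy types (not CINN's ir:: nodes), which also omits the write-buffer skip and the reduce-axis filtering of the real code: for C[i, j] = sum over k of A[i, k] * B[k, j], j is unused when reading A and i is unused when reading B, so the count is 2 and the block qualifies for multi-level tiling.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for a schedule block: its non-reduce iter vars plus, for each
// read region, the vars that index that region.
struct ToyBlock {
  std::vector<std::string> iter_vars;
  std::vector<std::vector<std::string>> read_vars;
};

// Mirrors the counting rule above: one or more iter vars unused by some read
// region marks the block for multi-level tiling.
bool ToyNeedsMultiLevelTiling(const ToyBlock& block) {
  int total_unused_iter_vars = 0;
  for (const auto& read : block.read_vars) {
    for (const auto& v : block.iter_vars) {
      if (std::find(read.begin(), read.end(), v) == read.end()) {
        ++total_unused_iter_vars;
      }
    }
  }
  return total_unused_iter_vars >= 1;
}

int main() {
  // Matmul-like block: iter vars {i, j}, reads A[i, k] and B[k, j].
  ToyBlock matmul{{"i", "j"}, {{"i", "k"}, {"k", "j"}}};
  std::cout << ToyNeedsMultiLevelTiling(matmul) << std::endl;  // prints 1
  return 0;
}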
+ +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/lowered_func.h" + +namespace cinn { +namespace auto_schedule { + +void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block); + +bool ContainsNodeType(ir::Expr expr, const std::unordered_set& node_types); + +/** + * Collects all input lowered_funcs and return names of all output arguments + */ +std::unordered_set GetOutputNamesFromLoweredFunc(const std::vector& lowered_funcs); + +/** + * Determine whether a schedule block needs multileveltiling + */ +bool NeedsMultiLevelTiling(const ir::ScheduleBlockRealize& sche_block_realize); + +/** + * Update a LoweredFunc by regenerating related fields with a new function body + */ +ir::LoweredFunc UpdateFuncWithNewBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body); + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc new file mode 100644 index 0000000000000..e51bd0e94cf26 --- /dev/null +++ b/paddle/cinn/auto_schedule/analysis/analyze_ir_test.cc @@ -0,0 +1,181 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/analysis/analyze_ir.h" + +#include +#include + +#include +#include + +#include "cinn/common/context.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace auto_schedule { + +TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_SimpleAssign) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + ir::Expr M(32); + ir::Expr N(32); + + lang::Placeholder A("A", {M, N}); + ir::Tensor B = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + poly::StageMap stages = poly::CreateStages({A, B}); + std::vector funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + + ASSERT_FALSE(funcs.empty()); + ir::Expr ast_expr = funcs[0]->body; + + VLOG(6) << "Analyzing for Expr:"; + VLOG(6) << ast_expr; + + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + std::vector all_block_realizes = ir_sch.GetAllBlocks(); + ASSERT_EQ(all_block_realizes.size(), 1UL); + + ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As(); + ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); + AnalyzeScheduleBlockReadWriteBuffer(sche_block); + + /* + * the sche_block_realize will be: + * ScheduleBlock(B) + * { + * i0, i1 = axis.bind(i, j) + * read_buffers(_A[i0(undefined:undefined), i1(undefined:undefined)]) + * write_buffers(_B[i0(undefined:undefined), i1(undefined:undefined)]) + * B[i0, i1] = A[i0, i1] + * } + */ + + VLOG(6) << "ScheduleBlockRealize: "; + VLOG(6) << all_block_realizes[0]; + + ASSERT_EQ(sche_block->read_buffers.size(), 1UL); + + std::stringstream read_ss; + read_ss << sche_block->read_buffers[0]; + ASSERT_EQ(read_ss.str(), "_A[i0(0:32), i1(0:32)]"); + + ASSERT_EQ(sche_block->write_buffers.size(), 1UL); + std::stringstream write_ss; + write_ss << sche_block->write_buffers[0]; + ASSERT_EQ(write_ss.str(), "_B[i0(0:32), i1(0:32)]"); +} + +TEST(AnalyzeIr, AnalyzeScheduleBlockReadWriteBuffer_AddDiffShape) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + ir::Expr M(32); + ir::Expr N(128); + + lang::Placeholder A("A", {M}); + lang::Placeholder B("B", {N}); + + ir::Tensor C = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); + + poly::StageMap stages = poly::CreateStages({C}); + std::vector funcs = lang::LowerVec("AddDiffShape", stages, {C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr before MultiLevelTiling: "; + VLOG(6) << ast_expr; + + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + std::vector all_block_realizes = ir_sch.GetAllBlocks(); + ASSERT_EQ(all_block_realizes.size(), 1UL); + + ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes[0].As(); + ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); + AnalyzeScheduleBlockReadWriteBuffer(sche_block); + + VLOG(6) << "ScheduleBlockRealize: "; + VLOG(6) << all_block_realizes[0]; + ASSERT_EQ(sche_block->read_buffers.size(), 2UL); + std::vector 
expect_read = {"_A[i0(0:32)]", "_B[i1(0:128)]"}; + + ASSERT_EQ(sche_block->read_buffers.size(), expect_read.size()); + for (size_t i = 0; i < expect_read.size(); ++i) { + std::stringstream read_ss; + read_ss << sche_block->read_buffers[i]; + ASSERT_EQ(read_ss.str(), expect_read[i]); + } + + ASSERT_EQ(sche_block->write_buffers.size(), 1UL); + std::stringstream write_ss; + write_ss << sche_block->write_buffers[0]; + ASSERT_EQ(write_ss.str(), "_C[i0(0:32), i1(0:128)]"); +} + +TEST(AnalyzeIr, ContainsNodeType) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + ir::Expr M(32); + ir::Expr N(32); + + lang::Placeholder A("A", {M, N}); + ir::Tensor B = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + poly::StageMap stages = poly::CreateStages({A, B}); + std::vector funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + + ASSERT_FALSE(funcs.empty()); + ir::Expr ast_expr = funcs[0]->body; + + VLOG(6) << "Analyzing for Expr:"; + VLOG(6) << ast_expr; + + ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::Store})); + ASSERT_TRUE(ContainsNodeType(ast_expr, {ir::IrNodeTy::Load, ir::IrNodeTy::IfThenElse})); + ASSERT_FALSE(ContainsNodeType(ast_expr, {ir::IrNodeTy::IfThenElse, ir::IrNodeTy::Sum})); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/auto_schedule.proto b/paddle/cinn/auto_schedule/auto_schedule.proto new file mode 100644 index 0000000000000..d5d8eff373fa3 --- /dev/null +++ b/paddle/cinn/auto_schedule/auto_schedule.proto @@ -0,0 +1,26 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax ="proto3"; + +package cinn.auto_schedule.proto; + +import "cinn/ir/schedule_desc.proto"; + +message TuningRecord { + string task_key = 1; + double execution_cost = 2; + double predicted_cost = 3; + cinn.ir.proto.ScheduleDesc trace = 4; +} diff --git a/paddle/cinn/auto_schedule/auto_tuner.cc b/paddle/cinn/auto_schedule/auto_tuner.cc new file mode 100644 index 0000000000000..86baae7007a56 --- /dev/null +++ b/paddle/cinn/auto_schedule/auto_tuner.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
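The TuningRecord message defined above is what the database persists for each measured schedule. A short sketch of filling one through the standard protobuf-generated C++ API; the generated header path is an assumption (it is produced by the proto_library rule from auto_schedule.proto):

#include <string>

#include "cinn/auto_schedule/auto_schedule.pb.h"  // assumed generated header path

int main() {
  cinn::auto_schedule::proto::TuningRecord record;
  record.set_task_key("fn_elementwise_add_0");  // hypothetical task key
  record.set_execution_cost(0.42);              // measured cost
  record.set_predicted_cost(0.40);              // cost-model estimate
  // record.mutable_trace() would carry the ScheduleDesc replay trace.
  std::string bytes;
  record.SerializeToString(&bytes);
  return 0;
}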
+ +#include "cinn/auto_schedule/auto_tuner.h" + +#include +#include + +#include +#include +#include + +#include "cinn/auto_schedule/database/jsonfile_database.h" +#include "cinn/auto_schedule/measure/schedule_measurer.h" +#include "cinn/auto_schedule/measure/simple_builder.h" +#include "cinn/auto_schedule/measure/simple_runner.h" +#include "cinn/auto_schedule/task/task_creator.h" +#include "cinn/auto_schedule/task/task_registry.h" +#include "cinn/auto_schedule/task/tune_task.h" +#include "cinn/auto_schedule/task_scheduler/task_scheduler.h" +#include "cinn/common/context.h" +#include "cinn/common/type.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/visualize_helper.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace auto_schedule { + +AutoTuner::AutoTuner(const common::Target& target, hlir::framework::Graph* graph) : target_(target), graph_(graph) {} + +void AutoTuner::Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler) { + // create builder, runner, and schedule measurer + builder_ = std::make_unique(graph_compiler); + runner_ = std::make_unique(config.runner_repeat_times); + schedule_measurer_ = std::make_unique(builder_.get(), runner_.get()); + + // initialize database + database_ = std::move(Database::Make(config.database_config)); + + // create tasks + TaskCreator task_creator; + tasks_ = task_creator.CreateTuneTaskOpLevel(graph_); + + const auto& dtype_dict = graph_->GetAttrs>("inferdtype"); + const auto& shape_dict = graph_->GetAttrs>("infershape"); + + op_lowerer_ = std::make_unique(dtype_dict, shape_dict, target_); + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + for (auto i = 0; i < tasks_.size(); ++i) { + auto&& task = tasks_[i]; + task.Initialize(shape_dict, dtype_dict, op_lowerer_.get()); + // Register the initial ModuleExpr corresponding to the task + task_registry->Regist(task.serialized_key, ir::ModuleExpr(task.GetLoweredFuncBodyExprs())); + VLOG(3) << "Add a task, id:" << i << ", serialized_key:\n" << task.serialized_key; + } + + // create task optimizers + utils::LinearRandomEngine::StateType initial_seed = utils::LinearRandomEngine::GetDeviceRandomValue(); + task_optimizers_.resize(tasks_.size()); + std::transform(tasks_.begin(), tasks_.end(), task_optimizers_.begin(), [&](TuneTask& task) { + return std::make_unique( + &task, schedule_measurer_.get(), database_.get(), utils::ForkRandomState(&initial_seed)); + }); + + // create task scheduler + task_scheduler_ = TaskScheduler::Make(tasks_, config.task_schedule_config, config.task_schedule_strategy); +} + +void PrintResult(std::shared_ptr group) { + if (!VLOG_IS_ON(3)) { + return; + } + + auto nodes = group->CollectNodes(); + VLOG(3) << "Node size:" << nodes.size(); + VLOG(3) << "Group {"; + for (auto* node : nodes) { + VLOG(3) << " " << hlir::framework::DebugString(node); + } + VLOG(3) << "}"; +} + +void PrintResult(const FunctionGroup& functions) { + if (!VLOG_IS_ON(3)) { + return; + } + + VLOG(3) << "Function size:" << functions.size(); + for (auto i = 0; i < functions.size(); ++i) { + const ir::LoweredFunc& func = functions.at(i); + VLOG(3) << "LoweredFunc-" << i << " detail:\n" << func; + } +} + +void PrintResult(const TuningResult& result) { + if (!VLOG_IS_ON(3)) { + return; + } + VLOG(3) << "###### Debug TuningResult ######\n"; + VLOG(3) << "Tuned SubGraph num:" << result.subgraphs.size(); + for (auto i = 0; i < result.subgraphs.size(); ++i) { + VLOG(3) << "****** SubGraph-" << i << " Detail ******\n"; + 
+    PrintResult(result.subgraphs.at(i));
+    VLOG(3) << "****** SubGraph End ******";
+  }
+
+  VLOG(3) << "Tuned FunctionGroup num:" << result.function_groups.size();
+  for (auto i = 0; i < result.function_groups.size(); ++i) {
+    VLOG(3) << "****** FunctionGroup-" << i << " Detail ******\n";
+    PrintResult(result.function_groups.at(i));
+    VLOG(3) << "****** FunctionGroup End ******";
+  }
+  VLOG(3) << "###### TuningResult End ######";
+}
+
+TuningResult AutoTuner::Tune(const TuningOptions& options) {
+  CHECK_GT(options.num_tuning_rounds, 0) << "Invalid config";
+  VLOG(3) << "Begin tuning with round num=" << options.num_tuning_rounds << ", tasks size=" << tasks_.size();
+
+  TuningResult result;
+  result.subgraphs.resize(tasks_.size());
+  result.function_groups.resize(tasks_.size());
+  // A task only tunes the schedule for now, so we populate its sub_graph
+  // as the default result of graph tuning; this should be updated
+  // once graph tuning is supported.
+  for (auto i = 0; i < tasks_.size(); ++i) {
+    auto&& task = tasks_.at(i);
+    result.subgraphs[i] = task.subgraph;
+  }
+
+  for (int r = 0; r < options.num_tuning_rounds; ++r) {
+    VLOG(3) << "<<<<<< Round " << r << " >>>>>>";
+    int run_id = -1;
+    task_scheduler_->Reset();
+    while ((run_id = task_scheduler_->NextTaskId()) != -1) {
+      VLOG(3) << "Start tuning Task-" << run_id;
+      auto* opt = task_optimizers_.at(run_id).get();
+      auto function_group = opt->Optimize(options);
+      VLOG(3) << "Task-" << run_id << " finished, print optimized functions:\n";
+      PrintResult(function_group);
+      // Update the best schedules searched so far.
+      result.function_groups.at(run_id) = std::move(function_group);
+    }
+  }
+
+  PrintResult(result);
+  return result;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/auto_tuner.h b/paddle/cinn/auto_schedule/auto_tuner.h
new file mode 100644
index 0000000000000..6a356bd3dd7b1
--- /dev/null
+++ b/paddle/cinn/auto_schedule/auto_tuner.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/auto_schedule/measure/schedule_measurer.h"
+#include "cinn/auto_schedule/task/task_optimizer.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
+#include "cinn/auto_schedule/tuning.h"
+#include "cinn/common/target.h"
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+#include "cinn/hlir/framework/op_lowering.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// This class is the entrance of auto-tuning. Users can use it to tune a graph
+// (not supported yet) and to search for a series of schedules that are likely
+// to deliver better performance.
+// Internally, it creates the necessary components and uses them to perform tuning.
+class AutoTuner {
+ public:
+  // Configures how auto-tuning is performed, e.g. the way tasks are created
+  // and the strategy used to schedule them.
+  struct Config {
+    std::string task_schedule_strategy = "round_robin";
+    TaskScheduler::Config task_schedule_config;
+    int runner_repeat_times = 1;
+    DatabaseConfig database_config;
+  };
+
+  AutoTuner(const common::Target& target, hlir::framework::Graph* graph);
+
+  // Initialize the tuner with a specific config and auxiliary objects.
+  void Initialize(const Config& config, hlir::framework::GraphCompiler* graph_compiler);
+
+  // Perform the tuning process and return the final result.
+  TuningResult Tune(const TuningOptions& options);
+
+ private:
+  const common::Target& target_;
+  hlir::framework::Graph* graph_;
+  std::unique_ptr<hlir::framework::OpLowerer> op_lowerer_;
+
+  // Tasks to tune.
+  std::vector<TuneTask> tasks_;
+  // Scheduler that selects a task to tune at every turn.
+  std::unique_ptr<TaskScheduler> task_scheduler_;
+  // The actors that perform auto-tuning; each optimizer takes one task.
+  std::vector<std::unique_ptr<TaskOptimizer>> task_optimizers_;
+
+  // Components used to measure auto-tune samples.
+  std::unique_ptr<ScheduleBuilder> builder_;
+  std::unique_ptr<ScheduleRunner> runner_;
+  std::unique_ptr<ScheduleMeasurer> schedule_measurer_;
+
+  // The database that stores tuning records.
+  std::unique_ptr<Database> database_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/auto_tuner_test.cc b/paddle/cinn/auto_schedule/auto_tuner_test.cc
new file mode 100644
index 0000000000000..362a279e852d1
--- /dev/null
+++ b/paddle/cinn/auto_schedule/auto_tuner_test.cc
@@ -0,0 +1,164 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
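With the declaration complete, the intended call pattern is construct, Initialize, then Tune. A minimal caller sketch, assuming graph and graph_compiler have already been built (the test below shows the full construction):

#include "cinn/auto_schedule/auto_tuner.h"

// Hypothetical helper: drives one tuning session with default settings.
cinn::auto_schedule::TuningResult TuneOnce(const cinn::common::Target& target,
                                           cinn::hlir::framework::Graph* graph,
                                           cinn::hlir::framework::GraphCompiler* graph_compiler) {
  cinn::auto_schedule::AutoTuner::Config config;
  config.task_schedule_strategy = "round_robin";  // the default strategy above

  cinn::auto_schedule::TuningOptions options;
  options.num_tuning_rounds = 2;  // hypothetical tuning budget

  cinn::auto_schedule::AutoTuner tuner(target, graph);
  tuner.Initialize(config, graph_compiler);
  return tuner.Tune(options);
}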
+ +#include "cinn/auto_schedule/auto_tuner.h" + +#include +#include + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/frontend/optimize.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/ir/ir_base.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(auto_schedule_use_cost_model); +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace auto_schedule { + +using ::cinn::hlir::framework::BuildScope; +using ::cinn::hlir::framework::Graph; +using ::cinn::hlir::framework::GraphCompiler; +using ::cinn::hlir::framework::Instruction; +using ::cinn::hlir::framework::Node; +using ::cinn::hlir::framework::Scope; + +class TestAutoTuner : public ::testing::Test { + public: +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + std::shared_ptr graph; + std::shared_ptr compiled_scope; + std::unique_ptr graph_compiler; + std::unique_ptr tuner; + + frontend::Program CreateAddReluProgram() { + frontend::NetBuilder builder("test"); + + auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A"); + auto b = builder.CreateInput(Float(32), {64}, "B"); + auto c = builder.Add(a, b, 1); + auto d = builder.Relu(c); + + return builder.Build(); + } + + void SetUp() override { + srand(0); + // AutoTuner is combined with new IR Schedule + FLAGS_cinn_ir_schedule = true; + std::unordered_set fetch_ids; + auto program = CreateAddReluProgram(); + auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); + compiled_scope = BuildScope(target, graph); + graph_compiler = std::make_unique(target, compiled_scope, graph); + tuner = std::make_unique(target, graph.get()); + } + + TuningResult InitializeAndTune(const AutoTuner::Config& config, const TuningOptions& options) { + tuner->Initialize(config, graph_compiler.get()); + return tuner->Tune(options); + } + + virtual void BasicCheckResult(const TuningResult& result) { + ASSERT_EQ(1, result.subgraphs.size()); + auto nodes = result.subgraphs.front()->CollectNodes(); + ASSERT_EQ(nodes.size(), 4UL); + ASSERT_EQ(nodes[0]->op()->name, "broadcast_to"); + ASSERT_EQ(nodes[1]->op()->name, "fill_constant"); + ASSERT_EQ(nodes[2]->op()->name, "elementwise_add"); + ASSERT_EQ(nodes[3]->op()->name, "max"); + + ASSERT_EQ(result.function_groups.size(), 1UL); + ASSERT_EQ(result.function_groups[0].size(), 1UL); + } + + virtual void ApplyTunedAndRun(const TuningResult& result) { + // build runtime program with tuning result + GraphCompiler::CompileOptions compile_options; + compile_options.with_instantiate_variables = true; + compile_options.Apply(result); + ASSERT_EQ(1, compile_options.groups.size()); + ASSERT_EQ(1, compile_options.lowered_funcs.size()); + VLOG(6) << "Print lowered_funcs before building"; + VLOG(6) << compile_options.lowered_funcs[0][0]; + VLOG(6) << compile_options.lowered_funcs[1][0]; + auto runtime_program = graph_compiler->Build(compile_options).runtime_program; + ASSERT_EQ(1, runtime_program->size()); + runtime_program->Execute(); + } + + void ZeroMeasure() { + // set config and options + AutoTuner::Config tuning_config; + tuning_config.task_schedule_strategy = "round_robin"; + + TuningOptions tuning_options; + tuning_options.num_measure_trials = 0; + auto result = InitializeAndTune(tuning_config, tuning_options); + BasicCheckResult(result); + 
ApplyTunedAndRun(result); + } + + void NonZeroMeasure() { + // set config and options + AutoTuner::Config tuning_config; + tuning_config.task_schedule_strategy = "round_robin"; + + TuningOptions tuning_options; + tuning_options.num_measure_trials = 4; + tuning_options.num_samples_per_iteration = 2; + + auto result = InitializeAndTune(tuning_config, tuning_options); + BasicCheckResult(result); + ApplyTunedAndRun(result); + } +}; + +TEST_F(TestAutoTuner, ZeroMeasure_DisableCostModel) { + FLAGS_auto_schedule_use_cost_model = false; + ZeroMeasure(); +} + +TEST_F(TestAutoTuner, ZeroMeasure_EnableCostModel) { + FLAGS_auto_schedule_use_cost_model = true; + ZeroMeasure(); +} + +TEST_F(TestAutoTuner, NonZeroMeasure_DisableCostModel) { + FLAGS_auto_schedule_use_cost_model = false; + NonZeroMeasure(); +} + +TEST_F(TestAutoTuner, NonZeroMeasure_EnableCostModel) { + FLAGS_auto_schedule_use_cost_model = true; + NonZeroMeasure(); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt b/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt new file mode 100644 index 0000000000000..6e52f7a3dad14 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/CMakeLists.txt @@ -0,0 +1,7 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS xgb_cost_model.cc expr_cost_model.cc feature.cc feature_extractor.cc) + +cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore) +cc_test(test_feature_extractor SRCS feature_extractor_test.cc DEPS cinncore) +cc_test(test_feature SRCS feature_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc new file mode 100644 index 0000000000000..e41a71a409109 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
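The two fixtures above differ only in their measurement budget. Restated as standalone option builders (values copied from the tests; the zero-trial interpretation is an assumption based on the test name, and the header is assumed to be where TuningOptions lives):

#include "cinn/auto_schedule/tuning.h"  // assumed home of TuningOptions

// ZeroMeasure: no candidate is built and run, so the search presumably
// relies on the cost model alone (assumption from the test name).
cinn::auto_schedule::TuningOptions ZeroMeasureOptions() {
  cinn::auto_schedule::TuningOptions options;
  options.num_measure_trials = 0;
  return options;
}

// NonZeroMeasure: four measured candidates in total, two per iteration.
cinn::auto_schedule::TuningOptions NonZeroMeasureOptions() {
  cinn::auto_schedule::TuningOptions options;
  options.num_measure_trials = 4;
  options.num_samples_per_iteration = 2;
  return options;
}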
+ +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" + +#include + +#include +#include + +#include "cinn/auto_schedule/cost_model/feature.h" +#include "cinn/auto_schedule/cost_model/feature_extractor.h" +#include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const { + if (trained_times_.load() == 0) { + return SearchState::NOT_INIT_COST; + } + FeatureExtractor extractor; + Feature feature = extractor.Extract(sample, target); + std::vector feature_numbers = feature.ToFixedSizeVector(); + std::vector pred = XgbCostModel::Predict({feature_numbers}); + return pred[0]; +} + +void ExprCostModel::Train(const std::vector& samples, + const std::vector& labels, + const common::Target& target) { + trained_times_.store(1); + size_t total_size = samples.size(); + CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; + std::vector> train_feature_numbers(total_size); + FeatureExtractor extractor; + for (size_t i = 0; i < total_size; ++i) { + CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; + Feature feature = extractor.Extract(*samples[i], target); + train_feature_numbers[i] = feature.ToFixedSizeVector(); + } + + XgbCostModel::Train(train_feature_numbers, labels); +} + +void ExprCostModel::Update(const std::vector& samples, + const std::vector& labels, + const common::Target& target) { + ++trained_times_; + size_t total_size = samples.size(); + CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels"; + std::vector> train_feature_numbers(total_size); + FeatureExtractor extractor; + for (size_t i = 0; i < total_size; ++i) { + CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr"; + Feature feature = extractor.Extract(*samples[i], target); + train_feature_numbers[i] = feature.ToFixedSizeVector(); + } + + XgbCostModel::Update(train_feature_numbers, labels); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h new file mode 100644 index 0000000000000..176424c785cb0 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/expr_cost_model.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <atomic>
+#include <vector>
+
+#include "cinn/auto_schedule/cost_model/xgb_cost_model.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * A C++ cost model which trains and predicts on ir::Expr
+ */
+class ExprCostModel : public XgbCostModel {
+ public:
+  virtual float Predict(const ir::ModuleExpr& sample, const common::Target& target) const;
+  void Train(const std::vector<const ir::ModuleExpr*>& samples,
+             const std::vector<float>& labels,
+             const common::Target& target);
+  void Update(const std::vector<const ir::ModuleExpr*>& samples,
+              const std::vector<float>& labels,
+              const common::Target& target);
+
+ private:
+  std::atomic<int> trained_times_{0};
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/feature.cc b/paddle/cinn/auto_schedule/cost_model/feature.cc
new file mode 100644
index 0000000000000..1c7f8158eb409
--- /dev/null
+++ b/paddle/cinn/auto_schedule/cost_model/feature.cc
@@ -0,0 +1,175 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
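A brief usage sketch of the ExprCostModel interface above; the samples and labels are hypothetical stand-ins for what the measurer produces:

#include <vector>

#include "cinn/auto_schedule/cost_model/expr_cost_model.h"

// Train on measured (ModuleExpr, cost) pairs, then score a new candidate.
// Until Train or Update has run at least once, Predict returns
// SearchState::NOT_INIT_COST (see the implementation above).
float SketchTrainThenPredict(const std::vector<const cinn::ir::ModuleExpr*>& samples,
                             const std::vector<float>& labels,
                             const cinn::ir::ModuleExpr& candidate,
                             const cinn::common::Target& target) {
  cinn::auto_schedule::ExprCostModel model;
  model.Train(samples, labels, target);
  return model.Predict(candidate, target);
}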
+ +#include "cinn/auto_schedule/cost_model/feature.h" + +#include + +#include + +#include "cinn/common/target.h" + +namespace cinn { +namespace auto_schedule { + +Feature::Feature() + : target_(common::UnkTarget()), + stack_encoded_feature_(1), // initialize a LoopBlockFeature as root block + current_loop_block_index_(0), + parent_indices_(1, -1) {} + +Feature::Feature(const common::Target& target) + : target_(target), + stack_encoded_feature_(1), // initialize a LoopBlockFeature as root block + current_loop_block_index_(0), + parent_indices_(1, -1) {} + +std::vector Feature::ToFixedSizeVector() { + std::vector ret(LoopBlockFeature::kTotalSize + 1, 0); // LoopBlockFeature::kTotalSize plus 1 for target + + if (target_ == common::DefaultNVGPUTarget()) { + ret[0] = 1; + } // else 0 for other cases + + // loop[i] feature count should multiply iter_multi_num[i] + std::vector iter_multi_num; + for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) { + int j = 1; + const LoopBlockFeature& loop_feature = stack_encoded_feature_[i]; + int loop_prod = 1; + int parent_prod = 1; + if (i != 0) { + parent_prod = iter_multi_num[parent_indices_[i]]; + loop_prod = parent_prod * loop_feature.loop_length; + } + iter_multi_num.push_back(loop_prod); + + ret[j] += (loop_feature.float_add_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.float_mul * loop_prod); + ++j; + ret[j] += (loop_feature.float_div_or_mod * loop_prod); + ++j; + ret[j] += (loop_feature.float_cmp * loop_prod); + ++j; + ret[j] += (loop_feature.float_math_func * loop_prod); + ++j; + ret[j] += (loop_feature.float_other_call * loop_prod); + ++j; + + ret[j] += (loop_feature.int_add_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.int_mul * loop_prod); + ++j; + ret[j] += (loop_feature.int_div_or_mod * loop_prod); + ++j; + ret[j] += (loop_feature.int_cmp * loop_prod); + ++j; + ret[j] += (loop_feature.int_math_func * loop_prod); + ++j; + ret[j] += (loop_feature.int_other_call * loop_prod); + ++j; + + ret[j] += (loop_feature.bool_op * loop_prod); + ++j; + ret[j] += (loop_feature.select_op * loop_prod); + ++j; + + ret[j] += (loop_feature.mem_alloc * loop_prod); + ++j; + ret[j] += (loop_feature.mem_free * loop_prod); + ++j; + ret[j] += (loop_feature.mem_read * loop_prod); + ++j; + ret[j] += (loop_feature.mem_write * loop_prod); + ++j; + + ret[j] += (loop_feature.float_reduce_sum_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_mul * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_div * loop_prod); + ++j; + ret[j] += (loop_feature.float_reduce_max_or_min * loop_prod); + ++j; + ret[j] += (loop_feature.float_broadcast * loop_prod); + ++j; + + ret[j] += (loop_feature.int_reduce_sum_or_sub * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_mul * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_div * loop_prod); + ++j; + ret[j] += (loop_feature.int_reduce_max_or_min * loop_prod); + ++j; + ret[j] += (loop_feature.int_broadcast * loop_prod); + ++j; + + ret[j + static_cast(loop_feature.loop_opt_type)] += 1; + j += LoopBlockFeature::kOptApplySize; + + ret[j] += (loop_feature.len_blockIdx_x * parent_prod); + ++j; + ret[j] += (loop_feature.len_blockIdx_y * parent_prod); + ++j; + ret[j] += (loop_feature.len_blockIdx_z * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_x * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_y * parent_prod); + ++j; + ret[j] += (loop_feature.len_threadIdx_z * parent_prod); + ++j; + ret[j] += (loop_feature.len_vthread * parent_prod); + ++j; + ret[j] += 
(loop_feature.vectorize_factor * parent_prod); + ++j; + } + + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = slog(ret[i]); + } + + return ret; +} + +void Feature::IntoLoopBlock() { + stack_encoded_feature_.emplace_back(LoopBlockFeature()); + stack_encoded_feature_[current_loop_block_index_].num_sub_loops += 1; + parent_indices_.push_back(current_loop_block_index_); + current_loop_block_index_ = stack_encoded_feature_.size() - 1; +} + +void Feature::ExitLoopBlock() { current_loop_block_index_ = parent_indices_[current_loop_block_index_]; } + +LoopBlockFeature& Feature::CurrentLoopBlock() { return stack_encoded_feature_[current_loop_block_index_]; } + +const LoopBlockFeature& Feature::CurrentLoopBlock() const { return stack_encoded_feature_[current_loop_block_index_]; } + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/feature.h b/paddle/cinn/auto_schedule/cost_model/feature.h new file mode 100644 index 0000000000000..019bd25382432 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/feature.h @@ -0,0 +1,178 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +/* Loop feature enums */ +enum class ForOptimizeFeatureEnum : int { kNone, kGpuBind, kParallel, kUnroll, kVectorize }; + +/* function to scale feature numbers */ +inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); } + +class LoopBlockFeature { + public: + // TODO(zhhsplendid): distinguish more types such as float16, float32, + // float64, etc. However speed the gap between float and int are larger than + // different bits, so we just distinguished int and float here + /* Arithmetic features */ + int float_add_or_sub = 0; + int float_mul = 0; + int float_div_or_mod = 0; + int float_cmp = 0; + int float_math_func = 0; + int float_other_call = 0; // like simple assign, cast, etc. + + int int_add_or_sub = 0; + int int_mul = 0; + int int_div_or_mod = 0; + int int_cmp = 0; + int int_math_func = 0; + int int_other_call = 0; // like simple assign, cast, etc. + + int bool_op = 0; + int select_op = 0; + + static constexpr int kArithSize = 6 * 2 + 2; + + /** + * Buffer memory features, which is the number of memory operations. + * Note that different size of memory operation can have various speed, + * however the speed difference would be small in OS. 
A meticulous TODO + * may be collect operand sizes (like alloc size, write size, or so) + */ + int mem_alloc = 0; + int mem_free = 0; + int mem_read = 0; + int mem_write = 0; + + static constexpr int kMemSize = 4; + + /** + * Reduce and Broadcast features + */ + int float_reduce_sum_or_sub = 0; + int float_reduce_mul = 0; + int float_reduce_div = 0; + int float_reduce_max_or_min = 0; + int float_broadcast = 0; + + int int_reduce_sum_or_sub = 0; + int int_reduce_mul = 0; + int int_reduce_div = 0; + int int_reduce_max_or_min = 0; + int int_broadcast = 0; + + static constexpr int kReduceBroadcastSize = 10; + + /* Loop type features */ + + // A TODO maybe add loop position (Inner, Outer, Middle) feature + + ForOptimizeFeatureEnum loop_opt_type = ForOptimizeFeatureEnum::kNone; + + static constexpr int kOptApplySize = 5; + + /* Thread features if loop is optimized by GPU or CPU parallelism. + * Useless in other cases. + */ + int len_blockIdx_x = 0; + int len_blockIdx_y = 0; + int len_blockIdx_z = 0; + int len_threadIdx_x = 0; + int len_threadIdx_y = 0; + int len_threadIdx_z = 0; + int len_vthread = 0; // length of virtual thread + int vectorize_factor = 0; + + static constexpr int kThreadFeatureSize = 8; + + static constexpr int kTotalSize = kArithSize + kMemSize + kReduceBroadcastSize + kOptApplySize + kThreadFeatureSize; + + /* Non-feature attributes, used to maintain during feature_extractor */ + + // Number to indicate the loop block inside current one + int num_sub_loops = 0; + + // Number of repeats of this loop, -1 represents unknown + int loop_length = 1; +}; + +/** + * Feature of Expr. It is used in CostModel + */ +class Feature { + public: + Feature(); + + Feature(const common::Target& target); + + // Convert the various-length loop block features to fixed-size vector + std::vector ToFixedSizeVector(); + + // Call when visit into a loop block to collect LoopBlockFeature + void IntoLoopBlock(); + // Call when exit a loop block to collect LoopBlockFeature + void ExitLoopBlock(); + // The current loop block which we should collect feature on + LoopBlockFeature& CurrentLoopBlock(); + // The current loop block which we should collect feature on + const LoopBlockFeature& CurrentLoopBlock() const; + + private: + // We treat a computation feature to be encoded as variable-length vector. + // The root compute block is not a loop, but we treat it as a size-1 loop. + // Blocks are encoded like a stack. Each LoopBlockFeature contains a + // num_sub_loops to indicate the next level sub-loop-block it contains. 
+ // + // For example, code like: + // + // some_compute_0 + // loop1 { + // some_compute_1 + // loop2 { + // some_compute_2 + // } + // } + // + // loop3 { + // some_compute_3 + // } + // + // We go through the code and push loops into stack, then the features are encoded as + // [loop_block_feature_0, loop_block_feature_1, loop_block_feature_2, loop_block_feature_3] + // where loop_block_feature_i stores the features of some_compute_i (such + // as number of arithmetic operations) + // + // loop_block_feature_0.num_sub_loops = 2 + // loop_block_feature_1.num_sub_loops = 1 + // loop_block_feature_2.num_sub_loops = 0 + // loop_block_feature_3.num_sub_loops = 0 + std::vector stack_encoded_feature_; + int current_loop_block_index_; + std::vector parent_indices_; + + common::Target target_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc new file mode 100644 index 0000000000000..5f44b2e3f0a8d --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor.cc @@ -0,0 +1,299 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
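To make the scaling described above concrete: a feature counted once inside nested loops is multiplied by the product of the enclosing loop extents, and the resulting vector is passed through slog. A toy calculation (numbers are illustrative only):

#include <cmath>
#include <iostream>

// Same definition as slog above: a sign-preserving log2 that maps 0 to 0
// and compresses large op counts.
inline float slog(float x) { return x < 0 ? std::log2(-x + 1) : std::log2(x + 1); }

int main() {
  // One float add nested in loops of extent 32 and 4 is counted
  // 32 * 4 = 128 times before scaling.
  float raw_count = 1.0f * 32 * 4;
  std::cout << slog(raw_count) << std::endl;  // log2(129) ~= 7.01
  return 0;
}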
+ +#include "cinn/auto_schedule/cost_model/feature_extractor.h" + +#include + +#include "cinn/common/target.h" +#include "cinn/common/type.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/transform_polyfor_to_for.h" + +namespace cinn { +namespace auto_schedule { + +using namespace ::cinn::ir; + +FeatureExtractor::FeatureExtractor() {} + +void FeatureExtractor::Visit(const Expr *x) { IRVisitor::Visit(x); } + +Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr, const common::Target &target) { + feature_ = Feature(target); + for (const ir::Expr &e : mod_expr.GetExprs()) { + Visit(&e); + } + return feature_; +} + +#define VisitDoNothing(NodeType) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + if (e->defined()) { \ + Visit(e); \ + } \ + } \ + } + +VisitDoNothing(IntImm); +VisitDoNothing(UIntImm); +VisitDoNothing(FloatImm); +VisitDoNothing(StringImm); + +VisitDoNothing(Block); +VisitDoNothing(_Module_); +VisitDoNothing(_Var_); +VisitDoNothing(_LoweredFunc_); +VisitDoNothing(ScheduleBlock); +VisitDoNothing(ScheduleBlockRealize); +VisitDoNothing(Ramp); +VisitDoNothing(_Buffer_); +VisitDoNothing(_BufferRange_); + +#define NotVisitExprFields(NodeType) \ + void FeatureExtractor::Visit(const NodeType *x) {} + +NotVisitExprFields(_Tensor_) + +#define VisitForDtypePattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ + feature_.CurrentLoopBlock().float_##member += x->type().lanes(); \ + } else { \ + feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \ + } \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + if (e->defined()) { \ + Visit(e); \ + } \ + } \ + } + + VisitForDtypePattern(Add, add_or_sub); +VisitForDtypePattern(Sub, add_or_sub); +VisitForDtypePattern(Minus, add_or_sub); +VisitForDtypePattern(Mul, mul); +VisitForDtypePattern(Div, div_or_mod); +VisitForDtypePattern(Mod, div_or_mod); +VisitForDtypePattern(FracOp, div_or_mod); +VisitForDtypePattern(EQ, cmp); +VisitForDtypePattern(NE, cmp); +VisitForDtypePattern(GT, cmp); +VisitForDtypePattern(GE, cmp); +VisitForDtypePattern(LT, cmp); +VisitForDtypePattern(LE, cmp); +VisitForDtypePattern(Call, math_func); +VisitForDtypePattern(PrimitiveNode, math_func); +VisitForDtypePattern(Cast, other_call); +VisitForDtypePattern(Let, other_call); + +#define VisitForMultiOperandsDtypePattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { \ + feature_.CurrentLoopBlock().float_##member += (x->operands().size() - 1); \ + } else { \ + feature_.CurrentLoopBlock().int_##member += (x->operands().size() - 1); \ + } \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + if (e->defined()) { \ + Visit(e); \ + } \ + } \ + } + +VisitForMultiOperandsDtypePattern(Sum, add_or_sub); +VisitForMultiOperandsDtypePattern(Product, mul); + +#define VisitCountMemberPattern(NodeType, member) \ + void FeatureExtractor::Visit(const NodeType *x) { \ + feature_.CurrentLoopBlock().member += 1; \ + std::vector sub_exprs = x->expr_fields(); \ + for (const Expr *e : sub_exprs) { \ + if (e->defined()) { \ + Visit(e); \ + } \ + } \ + } + 
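Since the three Visit macros above generate most of the extractor, it helps to see one expansion written out. Below is a hand-expanded form of VisitCountMemberPattern(Load, mem_read); the actual macro instantiations follow right after this sketch.

// Hand-expanded form of VisitCountMemberPattern(Load, mem_read), shown for
// illustration only; the instantiations below generate exactly this shape.
void FeatureExtractor::Visit(const Load *x) {
  // Each Load node counts as one memory read in the current loop block.
  feature_.CurrentLoopBlock().mem_read += 1;
  // Recurse into sub-expressions (e.g. the buffer and index expressions).
  std::vector<const Expr *> sub_exprs = x->expr_fields();
  for (const Expr *e : sub_exprs) {
    if (e->defined()) {
      Visit(e);
    }
  }
}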
+VisitCountMemberPattern(And, bool_op); +VisitCountMemberPattern(Or, bool_op); +VisitCountMemberPattern(Not, bool_op); +VisitCountMemberPattern(Max, select_op); +VisitCountMemberPattern(Min, select_op); +VisitCountMemberPattern(IfThenElse, select_op); +VisitCountMemberPattern(Select, select_op); +VisitCountMemberPattern(Alloc, mem_alloc); +VisitCountMemberPattern(Free, mem_free); +VisitCountMemberPattern(Load, mem_read); +VisitCountMemberPattern(Store, mem_write); + +/* Visit for loops */ + +void FeatureExtractor::Visit(const For *x) { + feature_.IntoLoopBlock(); + + LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock(); + if (x->min.is_constant() && x->extent.is_constant()) { + loop_feature.loop_length = (x->extent.get_constant() - x->min.get_constant()); + } else { + loop_feature.loop_length = -1; // -1 represents unknown + } + + if (x->is_parallel()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel; + loop_feature.len_vthread = loop_feature.loop_length; + } else if (x->is_unrolled()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll; + } else if (x->is_vectorized()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize; + loop_feature.vectorize_factor = x->vectorize_info().factor; + } else if (x->is_binded()) { + loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind; + const BindInfo &bind_info = x->bind_info(); + int offset = bind_info.offset; + if (bind_info.for_type == ForType::GPUBlock) { + if (offset == 0) { + loop_feature.len_blockIdx_x = loop_feature.loop_length; + } else if (offset == 1) { + loop_feature.len_blockIdx_y = loop_feature.loop_length; + } else if (offset == 2) { + loop_feature.len_blockIdx_z = loop_feature.loop_length; + } + } else if (bind_info.for_type == ForType::GPUThread) { + if (offset == 0) { + loop_feature.len_threadIdx_x = loop_feature.loop_length; + } else if (offset == 1) { + loop_feature.len_threadIdx_y = loop_feature.loop_length; + } else if (offset == 2) { + loop_feature.len_threadIdx_z = loop_feature.loop_length; + } + } + } + + std::vector sub_exprs = x->expr_fields(); + for (const Expr *e : sub_exprs) { + Visit(e); + } + + feature_.ExitLoopBlock(); +} + +void FeatureExtractor::Visit(const PolyFor *x) { + Expr copy = optim::IRCopy(Expr(x)); + feature_.IntoLoopBlock(); + optim::TransformPolyForToFor(©); + ir::For *loop = copy.As(); + CHECK(loop != nullptr); + Visit(loop); + feature_.ExitLoopBlock(); +} + +/* Visit for Reduce and Broadcast */ + +void FeatureExtractor::Visit(const Reduce *x) { + if (x->type() == common::F32() || x->type() == common::F16() || x->type() == common::F64()) { + switch (x->reduce_type) { + case Reduce::ReduceType::kSum: + feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kSub: + feature_.CurrentLoopBlock().float_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kDiv: + feature_.CurrentLoopBlock().float_reduce_div += x->type().lanes(); + break; + case Reduce::ReduceType::kMul: + feature_.CurrentLoopBlock().float_reduce_mul += x->type().lanes(); + break; + case Reduce::ReduceType::kMax: + feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); + break; + case Reduce::ReduceType::kMin: + feature_.CurrentLoopBlock().float_reduce_max_or_min += x->type().lanes(); + break; + } + } else { + switch (x->reduce_type) { + case Reduce::ReduceType::kSum: + feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes(); + break; + case Reduce::ReduceType::kSub: + 
feature_.CurrentLoopBlock().int_reduce_sum_or_sub += x->type().lanes();
+        break;
+      case Reduce::ReduceType::kDiv:
+        feature_.CurrentLoopBlock().int_reduce_div += x->type().lanes();
+        break;
+      case Reduce::ReduceType::kMul:
+        feature_.CurrentLoopBlock().int_reduce_mul += x->type().lanes();
+        break;
+      case Reduce::ReduceType::kMax:
+        feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes();
+        break;
+      case Reduce::ReduceType::kMin:
+        feature_.CurrentLoopBlock().int_reduce_max_or_min += x->type().lanes();
+        break;
+    }
+  }
+  std::vector<const Expr *> sub_exprs = x->expr_fields();
+  for (const Expr *e : sub_exprs) {
+    Visit(e);
+  }
+}
+VisitForDtypePattern(Broadcast, broadcast);
+
+/* Visit for IntrinsicOp */
+void FeatureExtractor::Visit(const IntrinsicOp *x) {
+  switch (x->getKind()) {
+#define __(op__)                                \
+  case IntrinsicKind::k##op__:                  \
+    Visit(llvm::dyn_cast<intrinsics::op__>(x)); \
+    break;
+
+    INTRINSIC_KIND_FOR_EACH(__)
+#undef __
+  }
+}
+
+VisitDoNothing(intrinsics::BufferGetDataHandle);
+VisitDoNothing(intrinsics::BufferGetDataConstHandle);
+VisitDoNothing(intrinsics::PodValueToX);
+VisitDoNothing(intrinsics::BufferCreate);
+VisitDoNothing(intrinsics::GetAddr);
+VisitDoNothing(intrinsics::ArgsConstruct);
+
+VisitForDtypePattern(intrinsics::BuiltinIntrin, other_call)
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor.h b/paddle/cinn/auto_schedule/cost_model/feature_extractor.h
new file mode 100644
index 0000000000000..073eee27cac77
--- /dev/null
+++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
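The IntrinsicOp dispatch above is driven by the INTRINSIC_KIND_FOR_EACH X-macro. Assuming the kind enumerators follow the k##op__ naming used in the macro, one generated case looks like the fragment below (hand-expanded for illustration; it is an excerpt of the switch body, not standalone code).

// One case generated by the __ macro above, assuming an enumerator named
// IntrinsicKind::kBufferCreate (naming taken from the VisitDoNothing list).
case IntrinsicKind::kBufferCreate:
  Visit(llvm::dyn_cast<intrinsics::BufferCreate>(x));
  break;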
+ +#pragma once + +#include "cinn/auto_schedule/cost_model/feature.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/ir_visitor.h" + +namespace cinn { +namespace auto_schedule { + +class FeatureExtractor : public ir::IRVisitor { + public: + FeatureExtractor(); + Feature Extract(const ir::ModuleExpr& mod_expr, const common::Target& target); + + void Visit(const Expr* x) override; + +#define __(op__) void Visit(const ir::op__* x) override; + NODETY_FORALL(__) +#undef __ + +#define __(op__) virtual void Visit(const ir::intrinsics::op__* x); + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + + private: + Feature feature_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc new file mode 100644 index 0000000000000..ed0cd984c93de --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/cost_model/feature_extractor.h" + +#include +#include + +#include +#include +#include + +#include "cinn/common/context.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace auto_schedule { + +TEST(FeatureExtractor, SimpleAssign) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + ir::Expr M(32); + ir::Expr N(32); + + lang::Placeholder A("A", {M, N}); + ir::Tensor B = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + poly::StageMap stages = poly::CreateStages({A, B}); + std::vector funcs = lang::LowerVec("SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true); + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr to test: " << ast_expr; + + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + + FeatureExtractor extractor; + + Feature feature = extractor.Extract(mod_expr, target); + + std::vector to_check = feature.ToFixedSizeVector(); + + ASSERT_EQ(to_check.size(), static_cast(LoopBlockFeature::kTotalSize + 1)); + VLOG(6) << "Feature data before slog:"; + for (size_t i = 0; i < to_check.size(); ++i) { + VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); + if (i != 0 && i != 17 && i != 18 && i != 29) { + ASSERT_EQ(to_check[i], 0); + } + } + // target +#ifdef CINN_WITH_CUDA + ASSERT_EQ(to_check[0], 1); +#else + ASSERT_EQ(to_check[0], 0); +#endif + // mem_read + ASSERT_EQ(to_check[17], slog(M.get_constant() * N.get_constant())); // mem_read + // mem_write + ASSERT_EQ(to_check[18], slog(M.get_constant() * 
N.get_constant())); // mem_write + // non-opt loops, including root block + ASSERT_EQ(to_check[29], slog(3)); +} + +TEST(FeatureExtractor, MatrixMultiply) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + ir::Expr M(2); + ir::Expr N(2); + ir::Expr K(4); + + lang::Placeholder A("A", {M, K}); + lang::Placeholder B("B", {K, N}); + + ir::Var k(K.as_int32(), "reduce_axis_k"); + ir::Tensor C = lang::Compute( + {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + poly::StageMap stages = poly::CreateStages({C}); + std::vector funcs = lang::LowerVec("MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); + + std::vector vec_ast{funcs[0]->body}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + std::vector blocks = ir_sch.GetAllBlocks(); + std::vector loops = ir_sch.GetLoops(blocks[0]); + ir_sch.Bind(loops.back(), "threadIdx.x"); + + ir::Expr ast_expr = mod_expr.GetExprs()[0]; + VLOG(6) << "Expr to test: " << ast_expr; + + FeatureExtractor extractor; + Feature feature = extractor.Extract(mod_expr, target); + + std::vector to_check = feature.ToFixedSizeVector(); + + ASSERT_EQ(to_check.size(), static_cast(LoopBlockFeature::kTotalSize + 1)); + std::unordered_set non_zero_indice = {0, 1, 2, 17, 18, 29, 30, 37}; + for (size_t i = 0; i < to_check.size(); ++i) { + VLOG(6) << i << " " << (std::pow(2, to_check[i]) - 1); + if (!non_zero_indice.count(i)) { + ASSERT_EQ(to_check[i], 0); + } + } + // target +#ifdef CINN_WITH_CUDA + ASSERT_EQ(to_check[0], 1); +#else + ASSERT_EQ(to_check[0], 0); +#endif + float out_loop = M.get_constant() * N.get_constant(); + float total_loop = out_loop * K.get_constant(); + // float_mul + ASSERT_EQ(to_check[1], slog(total_loop)); + // float_add_or_sub + ASSERT_EQ(to_check[2], slog(total_loop)); + // mem_read + ASSERT_EQ(to_check[17], slog(total_loop * 3)); + // mem_write + ASSERT_EQ(to_check[18], slog(total_loop + out_loop)); + + // non-opt loops, including root block + ASSERT_EQ(to_check[29], slog(3)); + // GpuBind loop + ASSERT_EQ(to_check[30], slog(1)); + // GpuBind loop + ASSERT_EQ(to_check[37], slog(out_loop)); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/feature_test.cc b/paddle/cinn/auto_schedule/cost_model/feature_test.cc new file mode 100644 index 0000000000000..908672d41b404 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/feature_test.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
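The assertions above compare raw counts through slog, and the VLOG lines invert the values with std::pow(2, v) - 1, which implies a base-2 log-plus-one transform. A hedged sketch of that helper follows; the real definition ships with the cost model utilities, not in this hunk.

#include <cmath>

// Hedged sketch of slog, inferred from how the tests above undo it via
// std::pow(2, v) - 1. This is an assumption about its shape, not its source.
float slog(float x) { return std::log2(x + 1.0f); }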
+ +#include "cinn/auto_schedule/cost_model/feature.h" + +#include +#include + +namespace cinn { +namespace auto_schedule { + +TEST(Feature, Basic) { + // TODO(zhhsplendid): add some basic tests +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc new file mode 100644 index 0000000000000..8549442688033 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/cost_model/xgb_cost_model.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/common/python_interpreter_guard.h" + +namespace cinn { +namespace auto_schedule { + +std::atomic XgbCostModel::xgb_cost_model_count_(0); + +// Convert 1D vector to py numpy +template +pybind11::array VectorToNumpy(const std::vector& vec) { + return pybind11::array(pybind11::cast(vec)); +} + +// Convert 2D vector to py numpy +template +pybind11::array VectorToNumpy(const std::vector>& vec) { + if (vec.size() == 0) { + return pybind11::array(pybind11::dtype::of(), {0, 0}); + } + + std::vector shape{vec.size(), vec[0].size()}; + pybind11::array ret(pybind11::dtype::of(), shape); + + Dtype* py_data = static_cast(ret.mutable_data()); + for (size_t i = 0; i < vec.size(); ++i) { + assert(vec[i].size() == shape[1] && "Sub vectors must have same size in VectorToNumpy"); + memcpy(py_data + (shape[1] * i), vec[i].data(), shape[1] * sizeof(Dtype)); + } + return ret; +} + +// the Pybind default Python interpreter doesn't contain some paths in +// sys.path, so we have to add it. +// +// Note: the Pybind default Python interpreter only uses default Python. +// Something may be wrong when users use virtual Python environment. +void AddDistPkgToPythonSysPath() { + pybind11::module sys_py_mod = pybind11::module::import("sys"); + // short version such as "3.7", "3.8", ... 
+ std::string py_short_version = sys_py_mod.attr("version").cast().substr(0, 3); + + std::string site_pkg_str = "/usr/local/lib/python" + py_short_version + "/dist-packages"; + sys_py_mod.attr("path").attr("append")(site_pkg_str); + + // TODO(zhhsplendid): warning to users if setuptools hasn't been installed + DIR* site_pkg_dir = opendir(site_pkg_str.c_str()); + if (site_pkg_dir != nullptr) { + std::regex setuptool_regex("setuptools-.*-py" + py_short_version + "\\.egg"); + struct dirent* entry = nullptr; + while ((entry = readdir(site_pkg_dir)) != nullptr) { + if (std::regex_match(entry->d_name, setuptool_regex)) { + sys_py_mod.attr("path").attr("append")(site_pkg_str + "/" + entry->d_name); + } + } + closedir(site_pkg_dir); + } +} + +XgbCostModel::XgbCostModel() { + common::PythonInterpreterGuard::Guard(); + int previous = xgb_cost_model_count_.fetch_add(1); + if (previous == 0) { + AddDistPkgToPythonSysPath(); + } + xgb_module_ = pybind11::module::import("xgboost"); + xgb_booster_ = xgb_module_.attr("Booster")(); +} + +void XgbCostModel::Train(const std::vector>& samples, const std::vector& labels) { + update_samples_ = samples; + update_labels_ = labels; + pybind11::array np_samples = VectorToNumpy(samples); + pybind11::array np_labels = VectorToNumpy(labels); + + pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); + xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); +} + +std::vector XgbCostModel::Predict(const std::vector>& samples) const { + pybind11::array np_samples = VectorToNumpy(samples); + pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples); + pybind11::array py_result = xgb_booster_.attr("predict")(dmatrix); + return py_result.cast>(); +} + +void XgbCostModel::Update(const std::vector>& samples, const std::vector& labels) { + update_samples_.insert(update_samples_.end(), samples.begin(), samples.end()); + update_labels_.insert(update_labels_.end(), labels.begin(), labels.end()); + pybind11::array np_samples = VectorToNumpy(update_samples_); + pybind11::array np_labels = VectorToNumpy(update_labels_); + + pybind11::object dmatrix = xgb_module_.attr("DMatrix")(np_samples, np_labels); + xgb_booster_ = xgb_module_.attr("train")(pybind11::dict(), dmatrix, pybind11::int_(kTrainRound_)); +} + +void XgbCostModel::Save(const std::string& path) { xgb_booster_.attr("save_model")(pybind11::str(path)); } + +void XgbCostModel::Load(const std::string& path) { xgb_booster_.attr("load_model")(pybind11::str(path)); } + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h new file mode 100644 index 0000000000000..69dbb8a7f3904 --- /dev/null +++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h @@ -0,0 +1,75 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
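Taken together, the methods above follow a plain train/predict/update cycle. Here is a minimal usage sketch; the sample data and save path are made up for illustration, and it mirrors what the unit test below exercises.

// Minimal usage sketch of the XgbCostModel API defined in this file.
void ExampleXgbCostModelUsage() {
  XgbCostModel model;
  std::vector<std::vector<float>> samples = {{1.0f, 2.0f}, {3.0f, 4.0f}};
  std::vector<float> labels = {0.5f, 0.9f};
  model.Train(samples, labels);
  std::vector<float> scores = model.Predict(samples);  // one score per sample
  model.Update(samples, labels);  // Update retrains on all accumulated data
  model.Save("/tmp/xgb_model.bin");  // illustrative path
}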
+
+#pragma once
+
+#include <pybind11/embed.h>
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+
+#include "cinn/common/cost_model.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * A C++ cost model that calls Python xgboost via pybind11.
+ *
+ * Note: this class manages the Python interpreter lifetime itself.
+ * If you need to call other Python functions outside this class and run into
+ * a lifetime conflict, see cinn::common::PythonInterpreterGuard.
+ *
+ * For cinn::common::PythonInterpreterGuard, see:
+ * cinn/common/python_interpreter_guard.h .cc
+ *
+ * For pybind11 interpreter lifetime management, see:
+ *
+ * https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#interpreter-lifetime
+ * https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv422initialize_interpreterbiPPCKcb
+ */
+class XgbCostModel : public CostModel {
+ public:
+  XgbCostModel();
+  ~XgbCostModel() = default;
+
+  void Train(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
+
+  std::vector<float> Predict(const std::vector<std::vector<float>>& samples) const override;
+
+  void Update(const std::vector<std::vector<float>>& samples, const std::vector<float>& labels) override;
+
+  void Save(const std::string& path) override;
+
+  void Load(const std::string& path) override;
+
+ private:
+  // Python xgboost module
+  pybind11::module xgb_module_;
+  // Object pointing to the Python xgb.Booster()
+  pybind11::object xgb_booster_;
+  // Atomic counter used to manage the Python interpreter lifetime and the
+  // one-time package path setup
+  static std::atomic<int> xgb_cost_model_count_;
+  // Default number of training rounds
+  static constexpr int kTrainRound_ = 10;
+
+  std::vector<std::vector<float>> update_samples_;
+  std::vector<float> update_labels_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
new file mode 100644
index 0000000000000..f237699a94406
--- /dev/null
+++ b/paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "cinn/auto_schedule/cost_model/xgb_cost_model.h" + +#include +#include +#include + +#include +#include +#include +#include + +namespace cinn { +namespace auto_schedule { + +TEST(CostModel, Basic) { + XgbCostModel cost_model; + + srand(time(NULL)); + + int batch_size = 16; + int feature_size = 8; + std::vector labels(batch_size, 1.0); + std::vector> samples(batch_size, std::vector(feature_size)); + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < feature_size; ++j) { + samples[i][j] = rand() % 10; + } + } + + cost_model.Train(samples, labels); + std::vector pred = cost_model.Predict(samples); + + std::string path = "./test_cost_model.cpp_save_model"; + cost_model.Save(path); + + XgbCostModel load_cost_model; + load_cost_model.Load(path); + std::vector load_pred = cost_model.Predict(samples); + + ASSERT_EQ(pred.size(), load_pred.size()); + for (size_t i = 0; i < pred.size(); ++i) { + ASSERT_FLOAT_EQ(pred[i], load_pred[i]); + VLOG(6) << "pred[" << i << "] = " << pred[i]; + } + std::remove(path.c_str()); + + cost_model.Update(samples, labels); + pred = cost_model.Predict(samples); + for (size_t i = 0; i < pred.size(); ++i) { + VLOG(6) << "pred[" << i << "] = " << pred[i]; + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/database/CMakeLists.txt b/paddle/cinn/auto_schedule/database/CMakeLists.txt new file mode 100644 index 0000000000000..1c3ca9330ba8c --- /dev/null +++ b/paddle/cinn/auto_schedule/database/CMakeLists.txt @@ -0,0 +1,6 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS database.cc jsonfile_database.cc) + +cc_test(test_database SRCS database_test.cc DEPS cinncore) +cc_test(test_jsonfile_database SRCS jsonfile_database_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/database/database.cc b/paddle/cinn/auto_schedule/database/database.cc new file mode 100644 index 0000000000000..87cfd63007db4 --- /dev/null +++ b/paddle/cinn/auto_schedule/database/database.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/database/database.h" + +#include +#include +#include + +#include "cinn/auto_schedule/database/jsonfile_database.h" +#include "cinn/auto_schedule/task/task_registry.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/schedule_desc.h" + +namespace cinn { +namespace auto_schedule { + +bool TuningRecord::Compare::operator()(const TuningRecord& lhs, const TuningRecord& rhs) const { + return lhs.execution_cost < rhs.execution_cost; +} + +proto::TuningRecord TuningRecord::ToProto() const { + proto::TuningRecord record_proto; + record_proto.set_task_key(task_key); + record_proto.set_execution_cost(execution_cost); + record_proto.set_predicted_cost(predicted_cost); + record_proto.mutable_trace()->CopyFrom(trace); + return record_proto; +} + +Database::Database(int capacity_per_task) : capacity_per_task_(capacity_per_task) { + CHECK_GT(capacity_per_task_, 0) << "capacity_per_task_ should be greater than 0"; +} + +std::unique_ptr Database::Make(const DatabaseConfig& config) { + if (config.type == DatabaseType::kMemory) { + return std::make_unique(config.capacity_per_task); + } else if (config.type == DatabaseType::kJSONFile) { + return std::make_unique(config.capacity_per_task, config.record_file_path, true); + } + + LOG(FATAL) << "Unimplemented database type."; + return nullptr; +} + +void Database::Insert(const TuningRecord& record) { + auto& records = key2record_[record.task_key]; + records.emplace(record); + if (records.size() > capacity_per_task_) { + records.erase(std::prev(records.end())); + } +} + +bool Database::AddRecord(const TuningRecord& record) { + CHECK(!record.task_key.empty()) << "task_key of TuningRecord can't be empty"; + + Insert(record); + return Commit(record); +} + +std::vector Database::LookUp(const std::string& task_key) { + auto fit = key2record_.find(task_key); + if (fit == key2record_.end()) { + return {}; + } + + std::vector results; + results.reserve(fit->second.size()); + results.assign(fit->second.begin(), fit->second.end()); + return results; +} + +std::vector Database::GetTopK(const std::string& task_key, int k) { + auto fit = key2record_.find(task_key); + if (fit == key2record_.end() || k <= 0) { + return {}; + } + if (k > capacity_per_task_) { + LOG(WARNING) << "Top k=" << k << " is greater than the capacity, will adjust k=" << capacity_per_task_; + k = capacity_per_task_; + } + + std::vector results; + results.reserve(k); + for (const TuningRecord& record : fit->second) { + results.emplace_back(record); + if (results.size() == k) { + break; + } + } + return results; +} + +size_t Database::Size() { + auto res = + std::accumulate(key2record_.begin(), key2record_.end(), size_t(0), [](size_t res, const auto& kv) -> size_t { + return std::move(res) + kv.second.size(); + }); + return res; +} + +size_t Database::Count(const std::string& task_key) { + auto fit = key2record_.find(task_key); + if (fit == key2record_.end()) { + return 0; + } + return fit->second.size(); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/database/database.h b/paddle/cinn/auto_schedule/database/database.h new file mode 100644 index 0000000000000..4487272b23875 --- /dev/null +++ b/paddle/cinn/auto_schedule/database/database.h @@ -0,0 +1,102 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "cinn/auto_schedule/auto_schedule.pb.h"
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/ir/schedule_desc.pb.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// Record related data about the tuning process of a measure candidate
+struct TuningRecord {
+  // the unique key to identify a task
+  std::string task_key;
+  // the cost predicted by the CostModel
+  float predicted_cost;  // unit: us
+  // the ScheduleDesc of this tuning process
+  ir::proto::ScheduleDesc trace;
+  // the execution time of the candidate measured during the measure phase
+  double execution_cost;  // unit: us
+
+  TuningRecord() = default;
+  TuningRecord(const proto::TuningRecord& record)
+      : task_key(record.task_key()),
+        predicted_cost(record.predicted_cost()),
+        trace(record.trace()),
+        execution_cost(record.execution_cost()) {}
+  TuningRecord(const std::string& task_key, const SearchState& state, double execution_cost)
+      : task_key(task_key),
+        predicted_cost(state->predicted_cost),
+        trace(state->ir_schedule.GetTraceDesc().ToProto()),
+        execution_cost(execution_cost) {}
+
+  // convert to proto object
+  proto::TuningRecord ToProto() const;
+
+  // a binary comparison function that decides whether the left record
+  // is sorted in front of the right one
+  struct Compare {
+    bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const;
+  };
+};
+
+enum class DatabaseType : int { kMemory, kJSONFile };
+
+struct DatabaseConfig {
+  DatabaseType type = DatabaseType::kMemory;
+  int capacity_per_task = 2;
+  std::string record_file_path = "/tmp/tuning_record.json";
+};
+
+// A database that supports inserting and looking up historical tuning results with specified traits.
+// It can be implemented with a concrete storage backend to save/load the underlying data,
+// such as memory, a file, or a database server; this base class can be regarded as
+// the one using memory as its underlying storage medium.
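Before the class definition just below, a short usage sketch of the factory configuration above; the values are examples only, and Database::Make is the factory defined in database.cc earlier in this patch.

// Example: build a JSON-file-backed database through the Make factory.
void ExampleMakeDatabase() {
  DatabaseConfig config;
  config.type = DatabaseType::kJSONFile;
  config.capacity_per_task = 4;                          // illustrative value
  config.record_file_path = "/tmp/example_records.json"; // illustrative path
  std::unique_ptr<Database> db = Database::Make(config);
}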
+class Database {
+ public:
+  explicit Database(int capacity_per_task);
+  ~Database() = default;
+
+  // Create a Database with the specified config
+  static std::unique_ptr<Database> Make(const DatabaseConfig& config);
+
+  // add a record into the database
+  bool AddRecord(const TuningRecord& record);
+  // return all records whose task_keys are equal to the specified key
+  std::vector<TuningRecord> LookUp(const std::string& task_key);
+  // return the top-k records of the sorted candidates
+  std::vector<TuningRecord> GetTopK(const std::string& task_key, int k);
+  // return the total number of stored candidates
+  size_t Size();
+  // return the number of stored candidates with the specified key
+  size_t Count(const std::string& task_key);
+
+ protected:
+  // commit the newly added record into the underlying storage
+  virtual bool Commit(const TuningRecord& record) { return true; }
+  // insert a newly added record into the memory storage
+  void Insert(const TuningRecord& record);
+
+  // map from task_key to its records
+  std::unordered_map<std::string, std::multiset<TuningRecord, TuningRecord::Compare>> key2record_;
+  // the max number of candidates stored per task
+  const int capacity_per_task_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/database/database_test.cc b/paddle/cinn/auto_schedule/database/database_test.cc
new file mode 100644
index 0000000000000..2e06f4a56be0b
--- /dev/null
+++ b/paddle/cinn/auto_schedule/database/database_test.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+ +#include "cinn/auto_schedule/database/database.h" + +#include + +#include + +#include "cinn/auto_schedule/auto_schedule.pb.h" +#include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +class TestDatabase : public ::testing::Test { + public: + TestDatabase() : test_db(2) { + auto state = SearchState(ir::IRSchedule()); + test_db.AddRecord(TuningRecord("k1", state, 1.0)); + test_db.AddRecord(TuningRecord("k2", state, 2.0)); + test_db.AddRecord(TuningRecord("k2", state, 3.0)); + test_db.AddRecord(TuningRecord("k3", state, 3.0)); + test_db.AddRecord(TuningRecord("k3", state, 4.0)); + test_db.AddRecord(TuningRecord("k3", state, 5.0)); + test_db.AddRecord(TuningRecord("k4", state, 4.0)); + } + + void SetUp() override {} + Database test_db; +}; + +TEST_F(TestDatabase, Basic) { + ASSERT_EQ(test_db.Size(), 6); + auto records = test_db.LookUp("k3"); + // check the max number of stored candidates will + // be restricted to capacity_per_task + ASSERT_EQ(test_db.Count("k3"), 2); + ASSERT_EQ(records.size(), 2); + EXPECT_EQ(records[0].execution_cost, 3.0); + EXPECT_EQ(records[1].execution_cost, 4.0); +} + +TEST_F(TestDatabase, GetTopK) { + ASSERT_TRUE(test_db.GetTopK("k5", 2).empty()); + ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1); + + test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.2), 2.0)); + test_db.AddRecord(TuningRecord("k4", SearchState(ir::IRSchedule(), 1.0), 3.0)); + + auto records = test_db.GetTopK("k4", 3); + ASSERT_EQ(records.size(), 2); + EXPECT_FLOAT_EQ(records[0].predicted_cost, 1.2); + EXPECT_FLOAT_EQ(records[1].predicted_cost, 1.0); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database.cc b/paddle/cinn/auto_schedule/database/jsonfile_database.cc new file mode 100644 index 0000000000000..3a7eb677183f3 --- /dev/null +++ b/paddle/cinn/auto_schedule/database/jsonfile_database.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/database/jsonfile_database.h" + +#include +#include +#include + +#include + +#include "cinn/auto_schedule/auto_schedule.pb.h" +#include "cinn/auto_schedule/task/task_registry.h" +#include "cinn/utils/multi_threading.h" + +namespace cinn { +namespace auto_schedule { + +// append a line to file +void AppendLineToFile(const std::string& file_path, const std::string& line) { + std::ofstream os(file_path, std::ofstream::app); + CHECK(os.good()) << "Cannot open the file to write: " << file_path; + os << line << std::endl; +} + +// read lines from a json file +std::vector ReadLinesFromFile(const std::string& file_path, bool allow_new_file) { + std::ifstream is(file_path); + if (is.good()) { + std::vector json_strs; + for (std::string str; std::getline(is, str);) { + json_strs.push_back(str); + } + + return json_strs; + } + CHECK(allow_new_file) << "File doesn't exist: " << file_path; + std::ofstream os(file_path); + CHECK(os.good()) << "Cannot create new file: " << file_path; + return {}; +} + +JSONFileDatabase::JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file) + : Database(capacity_per_task), record_file_path_(record_file_path) { + VLOG(3) << "Auto schedule will save/load tuning records on file:" << record_file_path; + auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file); + std::vector all_records_proto(json_lines.size()); + + // convert JSON string to proto object + auto worker_fn = [this, &json_lines, &all_records_proto](int index) { + cinn::auto_schedule::proto::TuningRecord record_proto; + auto status = google::protobuf::util::JsonStringToMessage(json_lines[index], &record_proto); + CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index]; + all_records_proto[index].Swap(&record_proto); + }; + utils::parallel_run(worker_fn, utils::SequenceDispatcher(0, json_lines.size()), -1); + + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + + for (const auto& record_proto : all_records_proto) { + std::string task_key = record_proto.task_key(); + if (task_registry->Has(task_key)) { + VLOG(4) << "Add a measured TuningRecord with task_key=" << task_key; + Insert(TuningRecord(record_proto)); + } + } +} + +// convert a TuningRecord object to string in JSON format +std::string JSONFileDatabase::RecordToJSON(const TuningRecord& record) { + proto::TuningRecord record_proto = record.ToProto(); + std::string json_string; + auto status = google::protobuf::util::MessageToJsonString(record_proto, &json_string); + CHECK(status.ok()) << "Failed to serialize record to JSON, task key = " << record.task_key; + VLOG(4) << "json_string = \n" << json_string; + + return json_string; +} + +bool JSONFileDatabase::Commit(const TuningRecord& record) { + std::string json_string = RecordToJSON(record); + AppendLineToFile(record_file_path_, json_string); + + return true; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database.h b/paddle/cinn/auto_schedule/database/jsonfile_database.h new file mode 100644 index 0000000000000..540013c224d5f --- /dev/null +++ b/paddle/cinn/auto_schedule/database/jsonfile_database.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/auto_schedule/database/database.h" + +namespace cinn { +namespace auto_schedule { + +// JSONFileDatabase is a database implemented by JSON file to save/load underlying data. +class JSONFileDatabase : public Database { + public: + /*! + * \brief Build a JSONFileDatabase object from a json file. + * \param capacity_per_task The max number of candidates stored. + * \param record_file_path The path of the json file. + * \param allow_new_file Whether to create new file when the given path is not found. + */ + JSONFileDatabase(int capacity_per_task, const std::string& record_file_path, bool allow_new_file); + ~JSONFileDatabase() = default; + + // convert a TuningRecord object to string in JSON format + std::string RecordToJSON(const TuningRecord& record); + + protected: + // commit the newly added record into json file + bool Commit(const TuningRecord& record) override; + + // the name of the json file to save tuning records. + std::string record_file_path_; +}; + +// append a line to file +void AppendLineToFile(const std::string& file_path, const std::string& line); + +// read lines from a json file +std::vector ReadLinesFromFile(const std::string& file_path, bool allow_new_file = true); + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc new file mode 100644 index 0000000000000..6ace45ea19478 --- /dev/null +++ b/paddle/cinn/auto_schedule/database/jsonfile_database_test.cc @@ -0,0 +1,214 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
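Before the tests, a hedged sketch of how the pieces above compose: read a stored line back, parse it into a proto, and replay its trace on a fresh schedule. Every call here appears elsewhere in this patch except MakeFreshSchedule, which is a hypothetical helper; see MakeIRSchedule in the test below for one concrete way to rebuild a schedule.

// Hedged end-to-end sketch; error handling is trimmed for brevity.
void ExampleLoadAndReplay() {
  std::vector<std::string> lines = ReadLinesFromFile("/tmp/tuning_record.json", false);
  cinn::auto_schedule::proto::TuningRecord record_proto;
  auto status = google::protobuf::util::JsonStringToMessage(lines[0], &record_proto);
  CHECK(status.ok()) << "Failed to parse JSON: " << lines[0];
  TuningRecord record(record_proto);
  ir::IRSchedule sch = MakeFreshSchedule();  // hypothetical helper
  ir::ScheduleDesc::ReplayWithProto(record.trace, &sch);
}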
+ +#include "cinn/auto_schedule/database/jsonfile_database.h" + +#include +#include + +#include +#include + +#include "cinn/auto_schedule/search_space/search_state.h" +#include "cinn/auto_schedule/task/task_registry.h" +#include "cinn/cinn.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace auto_schedule { + +// Return lowerd ir AST for example functions used in this test +std::vector LowerCompute(const std::vector& shape, const Target& target) { + CHECK(shape.size() == 2) << "shape should be 2"; + std::vector domain; + for (auto i = 0; i < shape.size(); ++i) { + domain.emplace_back(shape[i]); + } + + Placeholder A("A", domain); + ir::Tensor B, C; + + B = Compute( + domain, [&A](Var i, Var j) { return A(i, j); }, "B"); + C = Compute( + domain, [&B](Var i, Var j) { return B(i, j); }, "C"); + + return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); +} + +// Create a new IRSchedule with copied ir::LoweredFunc AST +ir::IRSchedule MakeIRSchedule(const std::vector& lowered_funcs, const std::string& task_key) { + std::vector exprs; + for (auto&& func : lowered_funcs) { + exprs.emplace_back(optim::IRCopy(func->body)); + } + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + task_registry->Regist(task_key, ir::ModuleExpr(exprs)); + + return ir::IRSchedule(ir::ModuleExpr(exprs)); +} + +class TestJSONFileDatabase : public ::testing::Test { + public: + TestJSONFileDatabase() : record_file_path("/tmp/test_record.json"), test_db(2, record_file_path, true) {} + + void SetUp() override { lowered_funcs = LowerCompute({32, 32}, target); } + + void TearDown() override { + auto isFileExists = [](const std::string& file_path) -> bool { + std::ifstream f(file_path.c_str()); + return f.good(); + }; + if (isFileExists(record_file_path)) { + if (remove(record_file_path.c_str()) == 0) { + LOG(INFO) << "Successfully deleted file: " << record_file_path; + } else { + LOG(INFO) << "failed to delete file: " << record_file_path; + } + } else { + LOG(INFO) << "file: " << record_file_path << "does not exist."; + } + } + + std::string record_file_path; + JSONFileDatabase test_db; + std::vector lowered_funcs; + Target target = common::DefaultHostTarget(); +}; + +TEST_F(TestJSONFileDatabase, Serialize) { + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "test"); + auto fused = ir_sch.Fuse("B", {0, 1}); + VLOG(3) << "after Fuse, Expr: " << fused; + + TuningRecord record1("test", SearchState(std::move(ir_sch), 2.0), 1.0); + std::string str = test_db.RecordToJSON(record1); + VLOG(3) << "RecordToJSON: " << str; + // Because the serialization of protobuf does not guarantee the order, we give all possible results. 
+ std::string case1 = + "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" + "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; + std::string case2 = + "{\"taskKey\":\"test\",\"executionCost\":1,\"predictedCost\":2,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" + "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; + EXPECT_EQ(true, str == case1 || str == case2); +} + +TEST_F(TestJSONFileDatabase, SaveLoad) { + ir::IRSchedule ir_sch1 = MakeIRSchedule(lowered_funcs, "k1"); + auto fused1 = ir_sch1.Fuse("B", {0, 1}); + ir::IRSchedule ir_sch2 = MakeIRSchedule(lowered_funcs, "k2"); + + test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch1), 1.5), 1.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(std::move(ir_sch2), 3.5), 3.0)); + + std::vector strs = ReadLinesFromFile(record_file_path); + ASSERT_EQ(strs.size(), 2); + // Because the serialization of protobuf does not guarantee the order, we give all possible results. + std::string case1 = + "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"loops_index\",\"dtype\":\"INTS\",\"ints\":[0,1]},{\"name\":\"block_" + "name\",\"dtype\":\"STRING\",\"s\":\"B\"}]}]}}"; + std::string case2 = + "{\"taskKey\":\"k1\",\"executionCost\":1,\"predictedCost\":1.5,\"trace\":{\"steps\":[{\"type\":\"FuseWithName\"," + "\"outputs\":[\"e0\"],\"attrs\":[{\"name\":\"block_name\",\"dtype\":\"STRING\",\"s\":\"B\"},{\"name\":\"loops_" + "index\",\"dtype\":\"INTS\",\"ints\":[0,1]}]}]}}"; + EXPECT_EQ(true, strs[0] == case1 || strs[0] == case2); + EXPECT_EQ(strs[1], "{\"taskKey\":\"k2\",\"executionCost\":3,\"predictedCost\":3.5,\"trace\":{}}"); +} + +TEST_F(TestJSONFileDatabase, Basic) { + test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 8.0), 3.0)); + test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 7.0), 4.0)); + test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 6.0), 5.0)); + test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 4.0)); + + ASSERT_EQ(test_db.Size(), 6); + auto records = test_db.LookUp("k3"); + // check the max number of stored candidates will + // be restricted to capacity_per_task + ASSERT_EQ(test_db.Count("k3"), 2); + ASSERT_EQ(records.size(), 2); + EXPECT_EQ(records[0].execution_cost, 3.0); + EXPECT_EQ(records[1].execution_cost, 4.0); +} + +TEST_F(TestJSONFileDatabase, GetTopK) { + test_db.AddRecord(TuningRecord("k1", SearchState(MakeIRSchedule(lowered_funcs, "k1"), 1.0), 1.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 3.0)); + test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 3.0)); + 
test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 4.0)); + test_db.AddRecord(TuningRecord("k3", SearchState(MakeIRSchedule(lowered_funcs, "k3"), 1.0), 5.0)); + test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 2.0), 4.0)); + test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.2), 2.0)); + test_db.AddRecord(TuningRecord("k4", SearchState(MakeIRSchedule(lowered_funcs, "k4"), 1.0), 3.0)); + + auto records = test_db.GetTopK("k4", 3); + ASSERT_EQ(records.size(), 2); + EXPECT_FLOAT_EQ(records[0].predicted_cost, 1.2); + EXPECT_FLOAT_EQ(records[1].predicted_cost, 1.0); +} + +TEST_F(TestJSONFileDatabase, Reload) { + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1"); + auto fused = ir_sch.Fuse("B", {0, 1}); + test_db.AddRecord(TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0)); + test_db.AddRecord(TuningRecord("k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0)); + auto records = test_db.LookUp("k1"); + ASSERT_EQ(records.size(), 1); + + JSONFileDatabase new_db(2, record_file_path, false); + ASSERT_EQ(new_db.Size(), 2); + auto loaded_records = new_db.LookUp("k1"); + ASSERT_EQ(records.size(), loaded_records.size()); + EXPECT_EQ(records[0].task_key, loaded_records[0].task_key); + EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost); + EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost); + + // check the equality of trace info between original TuningRecord and the loaded TuningRecord + const auto& lhs_trace = records[0].trace; + const auto& rhs_trace = loaded_records[0].trace; + google::protobuf::util::MessageDifferencer dif; + static const google::protobuf::Descriptor* descriptor = cinn::ir::proto::ScheduleDesc_Step::descriptor(); + dif.TreatAsSet(descriptor->FindFieldByName("attrs")); + EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace)); + + // check the equality of module expr between original TuningRecord + // and the loaded TuningRecord by replaying with tracing ScheduleDesc + ir::IRSchedule lhs_sch = MakeIRSchedule(lowered_funcs, "k1"); + ir::IRSchedule rhs_sch = MakeIRSchedule(lowered_funcs, "k1"); + ir::ScheduleDesc::ReplayWithProto(lhs_trace, &lhs_sch); + ir::ScheduleDesc::ReplayWithProto(rhs_trace, &rhs_sch); + auto lhs_exprs = lhs_sch.GetModule().GetExprs(); + auto rhs_exprs = rhs_sch.GetModule().GetExprs(); + + ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size()); + for (auto i = 0; i < lhs_exprs.size(); ++i) { + std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i)); + std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i)); + size_t remove_prefix_len = 28; + ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len)); + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/measure/CMakeLists.txt b/paddle/cinn/auto_schedule/measure/CMakeLists.txt new file mode 100644 index 0000000000000..ea2e822368df2 --- /dev/null +++ b/paddle/cinn/auto_schedule/measure/CMakeLists.txt @@ -0,0 +1,6 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS schedule_measurer.cc simple_builder.cc simple_runner.cc) + +cc_test(test_simple_runner SRCS simple_runner_test.cc DEPS cinncore) +cc_test(test_measurer SRCS measurer_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/measure/measure.h b/paddle/cinn/auto_schedule/measure/measure.h new file mode 100644 index 0000000000000..124aa474d9948 --- /dev/null +++ b/paddle/cinn/auto_schedule/measure/measure.h @@ -0,0 +1,79 @@ 
+// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "cinn/auto_schedule/task/tune_task.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/instruction.h" +#include "cinn/runtime/cinn_runtime.h" + +namespace cinn { +namespace auto_schedule { + +// The input to a measurer +struct MeasureInput { + // The task object related to this measurement. + const TuneTask* task; + // lowered Exprs to be measured + std::vector lowered_funcs; + // It is used to pass for some arguments that maybe + // specified value in advance. default is null + const std::map* execution_args = nullptr; +}; + +// The result of a measurement +struct MeasureResult { + // The time cost of execution in average of running + // with a specific repeated times. + double execution_cost = 0.0; // unit: us + // The time cost of the whole measurement process including + // building and running + double elapsed_time = 0.0; // unit: us + // used to return detail messages once an error occurred during measurement, + // empty if nothing goes wrong + std::string error_msg; +}; + +// The result of building with input schedule +struct BuildResult { + // The scope that owns detail compilation infos of parameters in the runtime program + const hlir::framework::Scope* compiled_scope; + // The executable program + std::unique_ptr runtime_program; +}; + +// This interface defines how to generate executable objects +// with input schedule. A builder should not contain stateful data +// related to any task so it can be called parallelly among multiple +// processes of task tuning. +class ScheduleBuilder { + public: + virtual BuildResult Build(const MeasureInput& input) = 0; +}; + +// This interface defines how to run the built result. Like above ScheduleBuilder, +// a runner shoule be implemented with not bound to a specific task. +class ScheduleRunner { + public: + virtual MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) = 0; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/measure/measurer_test.cc b/paddle/cinn/auto_schedule/measure/measurer_test.cc new file mode 100644 index 0000000000000..5297cabad5296 --- /dev/null +++ b/paddle/cinn/auto_schedule/measure/measurer_test.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
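The two interfaces above are intentionally small; implementing them is mostly a matter of filling in one method. Below is a hedged sketch of a do-nothing runner against the ScheduleRunner contract (the tests that follow use a similar trick to inject failures). It demonstrates the interface only, not real measurement.

// Hedged sketch: a trivial ScheduleRunner that satisfies the interface but
// performs no measurement. A real runner would execute
// build_result.runtime_program and time it.
class NoopRunner : public ScheduleRunner {
 public:
  MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override {
    MeasureResult result;
    result.execution_cost = 0.0;  // placeholder; nothing is actually run
    return result;
  }
};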
+ +#include + +#include + +#include "cinn/auto_schedule/measure/schedule_measurer.h" +#include "cinn/auto_schedule/measure/simple_builder.h" +#include "cinn/auto_schedule/measure/simple_runner.h" +#include "cinn/auto_schedule/task/task_creator.h" +#include "cinn/common/target.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/frontend/optimize.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace auto_schedule { + +using ::cinn::hlir::framework::BuildScope; +using ::cinn::hlir::framework::Graph; +using ::cinn::hlir::framework::GraphCompiler; + +frontend::Program CreateAddReluProgram() { + constexpr int M = 32; + constexpr int N = 24; + frontend::NetBuilder builder("test"); + + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Relu(c); + return builder.Build(); +} + +class TestMeasurer : public ::testing::Test { + public: + std::unique_ptr graph_compiler; + std::vector tasks; + std::vector inputs; + + void SetUp() override { + FLAGS_cinn_ir_schedule = true; +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + std::unordered_set fetch_ids; + auto program = CreateAddReluProgram(); + auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); + auto scope = BuildScope(target, graph); + graph_compiler = std::make_unique(target, scope, graph); + TaskCreator task_creator; + tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs>("infershape"); + + auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); + inputs.reserve(tasks.size()); + for (int i = 0; i < tasks.size(); ++i) { + auto* task = &tasks[i]; + task->Initialize(shape_dict, dtype_dict, op_lowerer.get()); + MeasureInput input; + input.task = task; + input.lowered_funcs = task->lowered_funcs; + inputs.emplace_back(input); + } + } +}; + +class ThrowExceptionBuilder : public ScheduleBuilder { + struct Exception : public std::exception { + const char* what() const throw() { return "BuildError"; } + }; + BuildResult Build(const MeasureInput& input) override { throw Exception(); } +}; + +class ThrowExceptionRunner : public ScheduleRunner { + struct Exception : public std::exception { + const char* what() const throw() { return "RunError"; } + }; + MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override { throw Exception(); } +}; + +TEST_F(TestMeasurer, Basic) { + auto builder = std::make_unique(graph_compiler.get()); + auto runner = std::make_unique(1); + auto measurer = std::make_unique(builder.get(), runner.get()); + std::vector results = measurer->Measure(inputs); + ASSERT_EQ(inputs.size(), results.size()); +} + +TEST_F(TestMeasurer, CatchException) { + auto builder = std::make_unique(graph_compiler.get()); + auto runner = std::make_unique(1); + auto throw_builder = std::make_unique(); + auto throw_runner = std::make_unique(); + auto measurer_with_build_error = std::make_unique(throw_builder.get(), runner.get(), 2); + std::vector results = measurer_with_build_error->Measure(inputs); + ASSERT_EQ(inputs.size(), results.size()); + EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n"); + + // TODO(CtfGo): test parallel build after we support thread-safe 
compilation
+  auto measurer_with_run_error = std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
+  results = measurer_with_run_error->Measure(inputs);
+  ASSERT_EQ(inputs.size(), results.size());
+  EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/measure/schedule_measurer.cc b/paddle/cinn/auto_schedule/measure/schedule_measurer.cc
new file mode 100644
index 0000000000000..3662d831d3eb2
--- /dev/null
+++ b/paddle/cinn/auto_schedule/measure/schedule_measurer.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/measure/schedule_measurer.h"
+
+#include <chrono>
+
+#include "cinn/utils/multi_threading.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads)
+    : builder_(builder), runner_(runner), num_threads_(num_threads) {}
+
+std::vector<MeasureResult> ScheduleMeasurer::Measure(const std::vector<MeasureInput>& inputs) {
+  if (inputs.empty()) {
+    LOG(WARNING) << "inputs is empty";
+    return {};
+  }
+  std::vector<BuildResult> build_results(inputs.size());
+  std::vector<MeasureResult> results(inputs.size());
+
+  // define how to build a candidate with the specified index
+  auto build_fn = [builder = builder_, &inputs, &build_results, &results](int index) {
+    VLOG(6) << "Build candidate index: " << index;
+    auto m_start = std::chrono::steady_clock::now();
+    try {
+      build_results[index] = builder->Build(inputs[index]);
+    } catch (std::exception& e) {
+      results[index].error_msg = utils::StringFormat("Build failed, error: %s\n", e.what());
+    }
+    auto time_span =
+        std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start);
+    results[index].elapsed_time += static_cast<double>(time_span.count());
+  };
+
+  // define how to run a candidate with the specified index
+  auto run_fn = [runner = runner_, &inputs, &build_results, &results](int index) {
+    VLOG(6) << "Run candidate index: " << index;
+    auto m_start = std::chrono::steady_clock::now();
+    try {
+      // if an error occurred during building, skip running
+      if (results[index].error_msg.empty()) {
+        results[index] = runner->Run(inputs[index], build_results[index]);
+      }
+    } catch (std::exception& e) {
+      results[index].error_msg = utils::StringFormat("Run failed, error: %s\n", e.what());
+    }
+    auto time_span =
+        std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - m_start);
+    results[index].elapsed_time += static_cast<double>(time_span.count());
+  };
+
+  // measure a candidate by calling build and run successively
+  auto measure_fn = [&build_fn, &run_fn](int index) {
+    build_fn(index);
+    run_fn(index);
+  };
+  // num_threads_ defaults to 1, in which case all measurements are performed sequentially in place.
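+  // SequenceDispatcher hands out the indices [0, inputs.size()) to the worker
+  // threads; exceptions are caught inside build_fn/run_fn themselves, so a
+  // failing candidate records its error_msg without aborting the whole batch.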
+  utils::parallel_run(measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
+
+  VLOG(4) << "Measure " << inputs.size() << " candidates";
+  return results;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/measure/schedule_measurer.h b/paddle/cinn/auto_schedule/measure/schedule_measurer.h
new file mode 100644
index 0000000000000..bf093b2c199a5
--- /dev/null
+++ b/paddle/cinn/auto_schedule/measure/schedule_measurer.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "cinn/auto_schedule/measure/measure.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// Entrance of schedule measurement; it consists of two processes:
+// building the input schedules and running the generated code.
+class ScheduleMeasurer {
+ public:
+  ScheduleMeasurer(ScheduleBuilder* builder, ScheduleRunner* runner, int num_threads = 1);
+
+  // Measure a batch of inputs and return all results at once.
+  std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);
+
+ private:
+  // The handle to the ScheduleBuilder implementation; not owned.
+  ScheduleBuilder* builder_;
+  // The handle to the ScheduleRunner implementation; not owned.
+  ScheduleRunner* runner_;
+  // The number of threads used to perform measurement;
+  // a value greater than 1 enables parallel measurement.
+  const int num_threads_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/measure/simple_builder.cc b/paddle/cinn/auto_schedule/measure/simple_builder.cc
new file mode 100644
index 0000000000000..5921d1b63b026
--- /dev/null
+++ b/paddle/cinn/auto_schedule/measure/simple_builder.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
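+// Illustrative usage sketch (not part of this file; `compiler` and `input`
+// are hypothetical, standing for a configured GraphCompiler and a prepared
+// MeasureInput):
+//   SimpleBuilder builder(&compiler);
+//   BuildResult built = builder.Build(input);
+//   // built.runtime_program is then handed to a ScheduleRunner.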
+
+#include "cinn/auto_schedule/measure/simple_builder.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+using hlir::framework::GraphCompiler;
+
+SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler) : graph_compiler_(graph_compiler) {}
+
+BuildResult SimpleBuilder::Build(const MeasureInput& input) {
+  CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr)) << "empty handle to GraphCompiler";
+  GraphCompiler::CompileOptions compile_options;
+  compile_options.groups.emplace_back(input.task->subgraph);
+  compile_options.lowered_funcs.emplace_back(input.lowered_funcs);
+  compile_options.remove_unused_variables = false;
+  VLOG(5) << "call GraphCompiler to Build with Graph::Group size=" << compile_options.groups.size()
+          << ", lowered_funcs group size=" << compile_options.lowered_funcs.size();
+  GraphCompiler::CompilationResult compiled_result = graph_compiler_->Build(compile_options);
+
+  BuildResult build_result;
+  build_result.compiled_scope = graph_compiler_->GetScope().get();
+  build_result.runtime_program = std::move(compiled_result.runtime_program);
+  return build_result;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/measure/simple_builder.h b/paddle/cinn/auto_schedule/measure/simple_builder.h
new file mode 100644
index 0000000000000..8757a3e322207
--- /dev/null
+++ b/paddle/cinn/auto_schedule/measure/simple_builder.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "cinn/auto_schedule/measure/measure.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// This class utilizes the GraphCompiler bound to the graph to build
+// the input schedule into executable objects.
+class SimpleBuilder : public ScheduleBuilder {
+ public:
+  SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler);
+
+  // Build and pack the result
+  BuildResult Build(const MeasureInput& input) override;
+
+ private:
+  hlir::framework::GraphCompiler* graph_compiler_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/measure/simple_runner.cc b/paddle/cinn/auto_schedule/measure/simple_runner.cc
new file mode 100644
index 0000000000000..54660ccc93c56
--- /dev/null
+++ b/paddle/cinn/auto_schedule/measure/simple_runner.cc
@@ -0,0 +1,227 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
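+// Illustrative usage sketch (hypothetical driver code, continuing the
+// SimpleBuilder example above):
+//   SimpleRunner runner(/*repeat_times=*/3);
+//   MeasureResult res = runner.Run(input, built);
+//   // res.execution_cost holds the averaged per-run time in microseconds.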
+
+#include "cinn/auto_schedule/measure/simple_runner.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <random>
+
+#include "cinn/common/target.h"
+#include "cinn/hlir/framework/buffer.h"
+#include "cinn/hlir/framework/scope.h"
+#include "cinn/hlir/framework/tensor.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+using hlir::framework::Buffer;
+using hlir::framework::Shape;
+using hlir::framework::Tensor;
+
+// Parameters that need to be initialized to 0.
+// Key is the op name; value holds the indices of the op's input parameters
+// that must be zero-initialized.
+static const std::unordered_map<std::string, std::vector<int>> kInitWithZeroParams = {
+    {"lookup_table", {1}},
+    {"gather", {1}},
+    {"gather_nd", {1}},
+    {"scatter_assign", {2}},
+    {"scatter_add", {2}},
+};
+
+// Generate random values and write them to the output memory address
+static void PopulateRandomValue(const common::Type& type, const int numel, void* raw_ptr) {
+  std::random_device seed;
+  std::default_random_engine engine(seed());
+
+  if (type == common::Bool()) {
+    auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
+    std::bernoulli_distribution dist(0.5);
+    std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
+  } else if (type == common::I32()) {
+    auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
+    std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(), std::numeric_limits<int>::max());
+    std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
+  } else if (type == common::I64()) {
+    auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
+    std::uniform_int_distribution<int64_t> dist(std::numeric_limits<int64_t>::min(),
+                                                std::numeric_limits<int64_t>::max());
+    std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
+  } else if (type == common::F32()) {
+    auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
+    std::uniform_real_distribution<float> dist(std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
+    std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
+  } else {
+    CHECK_EQ(type.bytes(), 8) << "Unsupported type: " << type << ", type.bytes = " << type.bytes();
+    auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
+    std::uniform_int_distribution<int64_t> dist(std::numeric_limits<int64_t>::min(),
+                                                std::numeric_limits<int64_t>::max());
+    std::generate_n(fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
+  }
+}
+
+// Initialize a tensor with 0 if init_with_zero == true, otherwise initialize it with random values.
+static void InitTensorData(Tensor tensor, const common::Target& target, bool init_with_zero) {
+  int mem_size = tensor->shape().numel() * tensor->type().bytes();
+  auto* tensor_data = tensor->mutable_data(target, tensor->type());
+#ifdef CINN_WITH_CUDA
+  if (target == common::DefaultNVGPUTarget()) {
+    if (init_with_zero) {
+      cudaMemset(tensor_data, 0, mem_size);
+    } else {
+      void* tmp_buffer = malloc(mem_size);
+      PopulateRandomValue(tensor->type(), tensor->shape().numel(), tmp_buffer);
+      cudaMemcpy(tensor_data, tmp_buffer, mem_size, cudaMemcpyHostToDevice);
+      free(tmp_buffer);
+    }
+  }
+#endif
+  if (target == common::DefaultHostTarget()) {
+    if (init_with_zero) {
+      memset(tensor_data, 0, mem_size);
+    } else {
+      PopulateRandomValue(tensor->type(), tensor->shape().numel(), tensor_data);
+    }
+  }
+}
+
+// Find all parameter names in the task corresponding to the MeasureInput
+// that need to be initialized to 0 when measuring.
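+// For example, input #1 of "gather" holds indices into input #0, so filling
+// it with random bytes could address out-of-bounds memory; zero-initializing
+// it keeps every access valid during measurement.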
+static std::unordered_set<std::string> ParamsNeedInitWithZero(const MeasureInput& input) {
+  std::unordered_set<std::string> res;
+  std::vector<hlir::framework::Node*> nodes = input.task->subgraph->CollectNodes();
+  for (auto* node : nodes) {
+    if (kInitWithZeroParams.count(node->op()->name) != 0) {
+      std::vector<int> param_idxs = kInitWithZeroParams.at(node->op()->name);
+      const auto& inlinks = node->inlinks_in_order();
+      for (int param_idx : param_idxs) {
+        CHECK_GT(inlinks.size(), param_idx);
+        auto& edge = inlinks.at(param_idx);
+        std::string param_name = edge->source()->as<hlir::framework::NodeData>()->id();
+        VLOG(6) << "param needs to be initialized with 0: " << param_name;
+        res.insert(param_name);
+      }
+    }
+  }
+
+  return res;
+}
+
+SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
+  CHECK_GT(repeat_times_, 0) << "repeat_times must be greater than 0";
+}
+
+// Prepare the execution arguments of all instructions to run; an argument
+// may be taken from the input of the measurement, or be a newly allocated
+// buffer filled with random values.
+std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(const MeasureInput& input,
+                                                                  const BuildResult& build_result,
+                                                                  hlir::framework::Scope* temp_scope) {
+  std::map<std::string, cinn_pod_value_t> result;
+
+  const auto& target = input.task->target;
+  const auto* input_args = input.execution_args;
+  const auto* compiled_scope = build_result.compiled_scope;
+  const auto& instructions = build_result.runtime_program->GetRunInstructions();
+
+  std::unordered_set<std::string> params_need_init_with_zero = ParamsNeedInitWithZero(input);
+
+  auto fill_arg_fn = [&](const std::string& param) {
+    VLOG(6) << "Filling argument:" << param;
+    // the argument is duplicated and has already been prepared.
+    if (result.count(param)) {
+      return;
+    }
+
+    // if the input of the measurement specifies this argument,
+    // use it first.
+    if (input_args && input_args->count(param)) {
+      VLOG(6) << "Argument[" << param << "] use input value";
+      result.emplace(param, input_args->at(param));
+      return;
+    }
+
+    if (temp_scope->FindVar(param)) {
+      auto temp_tensor = temp_scope->GetTensor(param);
+      result.emplace(param, temp_tensor->buffer());
+      return;
+    }
+
+    // allocate a new buffer for this argument and store it in
+    // the temporary scope to be released at the proper time.
+    auto compiled_tensor = compiled_scope->GetTensor(param);
+    temp_scope->Var<Tensor>(param);
+    auto temp_tensor = temp_scope->GetTensor(param);
+    temp_tensor->Resize(compiled_tensor->shape());
+    temp_tensor->set_type(compiled_tensor->type());
+    temp_tensor->mutable_data(target, compiled_tensor->type());
+    InitTensorData(temp_tensor, target, params_need_init_with_zero.count(param) != 0);
+
+    result.emplace(param, temp_tensor->buffer());
+  };
+
+  for (auto&& instr : instructions) {
+    for (auto&& args : instr->GetInArgs()) {
+      std::for_each(args.begin(), args.end(), fill_arg_fn);
+    }
+
+    for (auto&& args : instr->GetOutArgs()) {
+      std::for_each(args.begin(), args.end(), fill_arg_fn);
+    }
+  }
+  return result;
+}
+
+MeasureResult SimpleRunner::Run(const MeasureInput& input, const BuildResult& build_result) {
+  MeasureResult result;
+  auto t_start = std::chrono::steady_clock::now();
+  // prepare execution arguments
+  VLOG(4) << "SimpleRunner prepare execution arguments";
+  hlir::framework::Scope temp_scope;  // used to store temporarily allocated data
+  auto execution_args = PrepareArgs(input, build_result, &temp_scope);
+
+  // Execute each instruction repeatedly and take the average as the cost.
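+  // Each instruction is launched repeat_times_ times back-to-back and its
+  // averaged wall time is accumulated; on CUDA targets the device must be
+  // synchronized before reading the clock, since kernel launches are
+  // asynchronous with respect to the host.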
+ result.execution_cost = 0; + const auto& instructions = build_result.runtime_program->GetRunInstructions(); + for (auto ct = 0; ct < instructions.size(); ++ct) { + auto&& instr = instructions.at(ct); + VLOG(5) << "Start running instruction-" << ct; + auto run_start = std::chrono::steady_clock::now(); + for (int i = 0; i < repeat_times_; ++i) { + instr->Run(&execution_args); + } +#ifdef CINN_WITH_CUDA + if (instr->target_ == common::DefaultNVGPUTarget()) { + CUDA_CALL(cudaDeviceSynchronize()); + } +#endif + auto time_span = + std::chrono::duration_cast(std::chrono::steady_clock::now() - run_start); + auto cost_avg = static_cast(time_span.count()) / repeat_times_; + result.execution_cost += cost_avg; + } + + auto time_span = std::chrono::duration_cast(std::chrono::steady_clock::now() - t_start); + result.elapsed_time = static_cast(time_span.count()); + + VLOG(4) << "A measurement done:repeat_times[" << repeat_times_ << "]total_elapsed_time[" << result.elapsed_time + << "]us,execution_cost[" << result.execution_cost << "]us"; + return result; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/measure/simple_runner.h b/paddle/cinn/auto_schedule/measure/simple_runner.h new file mode 100644 index 0000000000000..48b316a0d7c06 --- /dev/null +++ b/paddle/cinn/auto_schedule/measure/simple_runner.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/auto_schedule/measure/measure.h" +#include "cinn/hlir/framework/instruction.h" + +namespace cinn { +namespace auto_schedule { + +// This class utilize the built instructions to execute the generated +// kernels and count the elapsed time as the measurement of performance +class SimpleRunner : public ScheduleRunner { + public: + SimpleRunner(int repeat_times); + + MeasureResult Run(const MeasureInput& input, const BuildResult& build_result) override; + + private: + std::map PrepareArgs(const MeasureInput& input, + const BuildResult& build_result, + hlir::framework::Scope* temp_scope); + + private: + // The repeat times of running instructions, + // this runner will return the average time + const int repeat_times_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/measure/simple_runner_test.cc b/paddle/cinn/auto_schedule/measure/simple_runner_test.cc new file mode 100644 index 0000000000000..b20faa6734a52 --- /dev/null +++ b/paddle/cinn/auto_schedule/measure/simple_runner_test.cc @@ -0,0 +1,139 @@ + +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/measure/simple_runner.h" + +#include + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/frontend/optimize.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph_compiler.h" + +namespace cinn { +namespace auto_schedule { + +using ::cinn::hlir::framework::BuildScope; +using ::cinn::hlir::framework::Graph; +using ::cinn::hlir::framework::GraphCompiler; +using ::cinn::hlir::framework::Instruction; +using ::cinn::hlir::framework::Scope; + +class TestSimpleRunner : public ::testing::Test { + public: +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + std::shared_ptr graph; + std::shared_ptr compiled_scope; + std::unique_ptr graph_compiler; + std::unique_ptr task; + + MeasureInput input; + BuildResult build_result; + + static frontend::Program CreateAddReluProgram(); + void SetUp() override { + std::unordered_set fetch_ids; + auto program = CreateAddReluProgram(); + auto graph = cinn::frontend::Optimize(&program, fetch_ids, target); + compiled_scope = BuildScope(target, graph); + graph_compiler = std::make_unique(target, compiled_scope, graph); + auto runtime_program = graph_compiler->Build(); + const auto& instructions = runtime_program->GetRunInstructions(); + ASSERT_EQ(1, instructions.size()); + + build_result.compiled_scope = compiled_scope.get(); + build_result.runtime_program = std::move(runtime_program); + + task = std::make_unique(); +#ifdef CINN_WITH_CUDA + task->target = common::DefaultNVGPUTarget(); +#else + task->target = common::DefaultHostTarget(); +#endif + task->subgraph = graph->fusion_groups.front(); + input.task = task.get(); + } +}; + +frontend::Program TestSimpleRunner::CreateAddReluProgram() { + constexpr int M = 32; + constexpr int N = 24; + frontend::NetBuilder builder("test"); + + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Relu(c); + return builder.Build(); +} + +TEST_F(TestSimpleRunner, MeasureWithRandomValue) { + auto runner = std::make_unique(1); + ASSERT_NO_THROW(runner->Run(input, build_result)); +} + +TEST_F(TestSimpleRunner, MeasureWithSpecifiedArgs) { + auto ta = compiled_scope->GetTensor("A"); + ta->mutable_data(target); + auto tb = compiled_scope->GetTensor("B"); + tb->mutable_data(target); + std::map preset_args; + preset_args.emplace("A", ta->buffer()); + preset_args.emplace("B", tb->buffer()); + + auto runner = std::make_unique(1); + // specific several execution args + input.execution_args = &preset_args; + ASSERT_NO_THROW(runner->Run(input, build_result)); +} + +TEST_F(TestSimpleRunner, TimeMeasured) { + // set up a BuildResult object with one instruction of the `sleep` function + void (*sleep_fn)(void*, int32_t) = [](void*, int32_t) -> void { + std::this_thread::sleep_for(std::chrono::microseconds(100)); + }; + BuildResult build_result; + build_result.compiled_scope = nullptr; + std::vector> instructions; + 
instructions.emplace_back( + new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn")); + instructions.back()->SetLoweredFunc(reinterpret_cast(sleep_fn)); + instructions.back()->Finalize(); + build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions))); + + // to skip the condition check of params in Instruction::PreparePodArgs + std::map preset_args; + preset_args.emplace("empty_placeholder", cinn_pod_value_t()); + input.execution_args = &preset_args; + + auto runner = std::make_unique(2); + MeasureResult measure_result = runner->Run(input, build_result); + // because the kernel function will sleep 100 us, + // the cost time of execution and span in total must + // be greater than 100us and 200us (repeatedly running 2 times) respectively. + ASSERT_GE(measure_result.execution_cost, 100); + ASSERT_GE(measure_result.elapsed_time, 200); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt b/paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt new file mode 100644 index 0000000000000..eda51bbb7e568 --- /dev/null +++ b/paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt @@ -0,0 +1,9 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + cooperative_process.cc + ) + +if (WITH_CUDA) + nv_test(test_cooperative_process SRCS cooperative_process_test.cc DEPS cinncore auto_gen_rule_test_helper test_program_builder) +endif() diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc new file mode 100644 index 0000000000000..2b8c05e105f1d --- /dev/null +++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
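+// Sketch of the rewrite this rule performs (illustrative IR fragment; the
+// loop names and extents are hypothetical, with num_threads = 16):
+//   before: serial for (ax0, 0, 32)   // block annotated with cooperative_process
+//   after:  serial for (ax0, 0, 2)
+//             thread_bind[threadIdx.x] for (ax0_0, 0, 16)
+//           __syncthreads()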
+
+#include "cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/schedule_desc.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+int ExtractNumThreads(const ir::IRSchedule& ir_schedule, const std::string& bind_axis) {
+  const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
+  for (auto&& step : trace.Steps()) {
+    if (step.type == "Bind" && step.attrs.find("thread_axis") != step.attrs.end() &&
+        absl::get<std::string>(step.attrs.at("thread_axis")) == bind_axis) {
+      CHECK_EQ(step.inputs.at("loop").size(), 1);
+      return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32();
+    }
+  }
+  return 0;
+}
+
+std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
+  std::vector<std::string> candidate_block_names;
+  for (auto&& step : trace.Steps()) {
+    if (step.type == "AnnotateIntAttr" &&
+        absl::get<std::string>(step.attrs.at("key")) == ir::attr::cooperative_process) {
+      candidate_block_names.push_back(
+          step.inputs.at("block")[0].As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name);
+    }
+  }
+  return candidate_block_names;
+}
+
+bool CooperativeProcess::Apply(ir::IRSchedule* schedule) {
+  int num_threads = ExtractNumThreads(*schedule, "threadIdx.x");
+  const ir::ScheduleDesc& trace = schedule->GetTraceDesc();
+  std::vector<std::string> candidate_block_names = FindCandidates(trace);
+  for (auto&& candidate : candidate_block_names) {
+    auto loop = schedule->GetLoops(candidate).back();
+    if (loop.As<ir::For>()->extent.as_int32() <= num_threads) {
+      schedule->Bind(loop, "threadIdx.x");
+      loop = schedule->GetLoops(candidate).back();
+      schedule->SyncThreads(loop);
+    } else {
+      auto split_loops = schedule->Split(loop, {-1, num_threads});
+      schedule->Bind(split_loops.back(), "threadIdx.x");
+      schedule->SyncThreads(split_loops[0]);
+    }
+    auto block = schedule->GetBlock(candidate);
+    schedule->Unannotate(block, ir::attr::cooperative_process);
+  }
+  return true;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h
new file mode 100644
index 0000000000000..9f106dfda0eb3
--- /dev/null
+++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/*
+ * @brief Rewrite the cooperative_process annotation to actually bind the loop on threadIdx.
+ * This rule is used for cooperative data handling by multiple threads within the same block.
+ */ +class CooperativeProcess : public PostScheduleRule { + public: + CooperativeProcess() = default; + + bool Apply(ir::IRSchedule* schedule) final; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc new file mode 100644 index 0000000000000..c10005a910969 --- /dev/null +++ b/paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc @@ -0,0 +1,199 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/post_schedule_rule/cooperative_process.h" + +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" +#include "cinn/ir/ir_printer.h" +#include "tests/program_builder.h" + +namespace cinn { +namespace auto_schedule { + +class TestCooperativeProcess : public TestAutoGenRuleBase { + public: + int fixed_rand_seed = 1; + std::vector default_input_names; + std::vector default_output_names; +}; + +TEST_F(TestCooperativeProcess, Matmul) { + default_input_names = {"X", "Y"}; + default_output_names = {"temp_matmul_out"}; + std::vector X_shape = {32, 32}; + std::vector Y_shape = {32, 32}; + std::vector out_shape = {32, 32}; + + int num_blocks_y = 2; + int num_blocks_x = 2; + int num_threads_y = 8; + int num_threads_x = 2; + int steps_k = 8; + + Initialize(common::DefaultNVGPUTarget()); + frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); + + // split loops + std::vector loops = ir_schedule.GetLoops("temp_matmul_out"); + std::vector k_loops = ir_schedule.Split(loops[2], {steps_k, -1}); + std::vector j_loops = ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1}); + std::vector i_loops = ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1}); + // reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2 + loops = ir_schedule.GetLoops("temp_matmul_out"); + ir_schedule.Reorder({loops[0], loops[3], loops[1], loops[4], loops[6], loops[7], loops[2], loops[5]}); + // fuse and bind + loops = ir_schedule.GetLoops("temp_matmul_out"); + ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]}); + ir::Expr i0_j0_fused = ir_schedule.Fuse({loops[0], loops[1]}); + loops = ir_schedule.GetLoops("temp_matmul_out"); + ir_schedule.Bind(loops[1], "threadIdx.x"); + ir_schedule.Bind(loops[0], "blockIdx.x"); + // cache read + ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out"); + ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared"); + std::string X_cache_block_name = + X_cache_block.As()->schedule_block.As()->name; + loops = ir_schedule.GetLoops("temp_matmul_out"); + ir_schedule.ComputeAt(X_cache_block, loops[2]); + std::vector X_cache_loops = ir_schedule.GetLoops(X_cache_block_name); + ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]}); + 
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name), ir::attr::cooperative_process, 0); + + out_block = ir_schedule.GetBlock("temp_matmul_out"); + ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared"); + std::string Y_cache_block_name = + Y_cache_block.As()->schedule_block.As()->name; + loops = ir_schedule.GetLoops("temp_matmul_out"); + ir_schedule.ComputeAt(Y_cache_block, loops[2]); + std::vector Y_cache_loops = ir_schedule.GetLoops(Y_cache_block_name); + ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]}); + ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name), ir::attr::cooperative_process, 0); + + // apply CooperativeProcess + CooperativeProcess cooperative_process; + cooperative_process.Apply(&ir_schedule); + + // check ir + auto ir = GetIR(ir_schedule); + VLOG(6) << "after CooperativeProcess, ir: \n" << ir; + std::string expected_ir = R"ROC(Expr 0 { +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 2) + { + serial for (j, 0, 2) + { + serial for (i_0, 0, 8) + { + serial for (j_0, 0, 2) + { + serial for (i_1, 0, 2) + { + serial for (j_1, 0, 8) + { + ScheduleBlock(temp_matmul_out__reduce_init) + { + i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))) + { + temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f + } + } + } + } + } + } + } + } + thread_bind[blockIdx.x] for (i_j_fused, 0, 4) + { + thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 16) + { + serial for (reduce_k_0, 0, 8) + { + serial for (ax0_0_ax1_0_fused, 0, 2) + { + thread_bind[threadIdx.x] for (ax0_0_ax1_0_fused_0, 0, 16) + { + ScheduleBlock(Y_reshape_shared_temp_buffer) + { + v0, v1 = axis.bind(((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) / 8) + (4 * reduce_k_0)), ((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) % 8) + ((8 * (i_0_j_0_fused % 2)) + (16 * (i_j_fused % 2))))) + attrs(compute_at_extra_var:ax0_0,ax1_0) + { + Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1] + } + } + } + } + __syncthreads() + thread_bind[threadIdx.x] for (ax0_ax1_fused, 0, 8) + { + ScheduleBlock(X_reshape_shared_temp_buffer) + { + v0, v1 = axis.bind(((ax0_ax1_fused / 4) + ((2 * (i_0_j_0_fused / 2)) + (16 * (i_j_fused / 2)))), ((ax0_ax1_fused % 4) + (4 * reduce_k_0))) + attrs(compute_at_extra_var:ax0,ax1) + { + X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1] + } + } + } + __syncthreads() + serial for (reduce_k_1, 0, 4) + { + serial for (i_1, 0, 2) + { + serial for (j_1, 0, 8) + { + ScheduleBlock(temp_matmul_out) + { + i0_0, i1_0, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1)) + { + temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))])) + } + } + } + } + } + } + } + } + } + } +} +} // end Expr 0 +)ROC"; + ASSERT_EQ(ir, expected_ir); + + // build ir::Module and debug source code + auto ir_module = BuildIRModule(ir_schedule); + auto source_code = GenSourceCode(ir_module); + VLOG(6) << "scheduled source 
code:\n" << source_code; + + // execute and check precision + CheckResult( + GenExecutableKernel(ir_module), + GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), + default_input_names, + default_output_names, + {X_shape, Y_shape}, + {out_shape}, + target_); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h b/paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h new file mode 100644 index 0000000000000..136d4fc18f297 --- /dev/null +++ b/paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +/** + * Base class for rules of post process, + * used to process schedules that rely on mutate results. + */ +class PostScheduleRule { + public: + PostScheduleRule() = default; + + /** + * @brief Apply the post schedule rule to the given SearchState. + * @param state The given SearchState for post schedule. + * @return True if apply successfully. + */ + virtual bool Apply(ir::IRSchedule* schedule) = 0; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/CMakeLists.txt b/paddle/cinn/auto_schedule/search_space/CMakeLists.txt new file mode 100644 index 0000000000000..44d73649efaec --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/CMakeLists.txt @@ -0,0 +1,15 @@ +add_subdirectory(auto_gen_rule) + +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + search_space.cc + search_state.cc + block_sampler.cc + rule_sampler.cc + ) + +cc_test(test_search_space SRCS search_space_test.cc DEPS cinncore) +cc_test(test_search_state SRCS search_state_test.cc DEPS cinncore) +cc_test(test_block_sampler SRCS block_sampler_test.cc DEPS cinncore) +cc_test(test_rule_sampler SRCS rule_sampler_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt new file mode 100644 index 0000000000000..dcb81c71baefd --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt @@ -0,0 +1,24 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + auto_gen_rule.cc + auto_inline.cc + auto_unroll.cc + multi_level_tiling.cc + skip_rule.cc + auto_bind.cc +) + +if (WITH_TESTING) + cc_library(auto_gen_rule_test_helper SRCS test_helper.cc DEPS glog gtest cinncore) +endif() + +if (WITH_CUDA) + nv_test(test_mix_rules SRCS mix_rules_test.cc DEPS cinncore auto_gen_rule_test_helper test_program_builder) + nv_test(test_auto_bind SRCS auto_bind_test.cc DEPS cinncore auto_gen_rule_test_helper test_program_builder) + nv_test(test_multi_level_tiling SRCS multi_level_tiling_test.cc DEPS cinncore auto_gen_rule_test_helper test_program_builder) +endif() + 
+#cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper) +cc_test(test_skip_rule SRCS skip_rule_test.cc DEPS cinncore) +cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc new file mode 100644 index 0000000000000..0a49d8c269645 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h" + +#include + +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace auto_schedule { + +static constexpr uint32_t kMaxBlocks = 256; +// check whether the input ir::For is a spatial loop +bool IsSpatialLoop(const ir::For* for_node) { + if (for_node->for_type() != ir::ForType::Serial) return false; + const auto& loop_var = for_node->loop_var; + // collect cases where the loop_var used in one of reduce axis in underneath ScheduleBlock + auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(for_node->body, [&loop_var](const Expr* x) { + const auto* block_realize = x->As(); + if (!block_realize) return false; + + const auto* schedule_block = block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; + CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); + for (int i = 0; i < block_realize->iter_values.size(); ++i) { + const ir::Var& iter_var = schedule_block->iter_vars[i]; + const ir::Expr& binding = block_realize->iter_values[i]; + if (iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce") { + auto used_exprs = ir::CollectIRNodesWithoutTensor(binding, [&loop_var](const Expr* x) { + const ir::_Var_* var = x->As(); + if (var && (x->same_as(loop_var) || var->name == loop_var->name)) { + return true; + } + return false; + }); + if (!used_exprs.empty()) return true; + } + } + + return false; + }); + + if (!used_for_reduce_axis.empty()) return false; + return true; +} + +// count the number of loops that can be binded from the input for_node to bottom +int CountLoopCanBinded(const ir::For* for_node) { + int cnt = 0; + while (for_node) { + if (for_node->is_binded()) break; // has binded + if (!IsSpatialLoop(for_node)) break; // only spatial loops to be binded + + cnt += 1; + + CHECK(for_node->body.defined() && for_node->body.As()) << "Body is not defined"; + const ir::Block* body = for_node->body.As(); + // terminate when body of this loop has more than one statement or the body is not a ir::For node + for_node = body->stmts.size() == 1 ? 
body->stmts[0].As() : nullptr; + } + return cnt; +} + +void BindGPUIndex(ir::IRSchedule* ir_schedule, + const std::string& block_name, + int num_loops_to_bind, + int max_blocks, + int max_threads_per_block) { + auto all_loops = ir_schedule->GetLoops(block_name); + CHECK_LE(num_loops_to_bind, all_loops.size()) << "The number of loops to be bind is greater than size of all_loops"; + // check whether it is the case that threadIdx has been binded but blockIdx not, + // the threadIdx can only be binded in the first loop after num_loops_to_bind loops + // because we has excluded other cases in CountLoopCanBinded + bool gpu_thread_has_binded = + num_loops_to_bind < all_loops.size() && all_loops[num_loops_to_bind].As()->is_gpu_thread_binded(); + Expr fused_loop = ir_schedule->Fuse({all_loops.begin(), all_loops.begin() + num_loops_to_bind}); + int32_t extent = fused_loop.As()->extent.as_int32(); + if (gpu_thread_has_binded) { + ir_schedule->Bind(fused_loop, "blockIdx.x"); + return; + } + + if (extent <= max_threads_per_block) { + ir_schedule->Bind(fused_loop, "threadIdx.x"); + return; + } + + if (extent <= max_blocks * max_threads_per_block) { + auto splits = ir_schedule->Split(fused_loop, {-1, max_threads_per_block}); + CHECK_EQ(splits.size(), 2); + ir_schedule->Bind(splits[0], "blockIdx.x"); + ir_schedule->Bind(splits[1], "threadIdx.x"); + } else { + auto splits = ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block}); + CHECK_EQ(splits.size(), 3); + ir_schedule->Reorder({splits[1], splits[2], splits[0]}); + all_loops = ir_schedule->GetLoops(block_name); + ir_schedule->Bind(all_loops[0], "blockIdx.x"); + ir_schedule->Bind(all_loops[1], "threadIdx.x"); + } +} + +RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) { + ir_schedule_ = ir_schedule; + + for (auto&& block_realize : ir_schedule->GetAllBlocks()) { + auto all_loops = ir_schedule->GetLoops(block_realize); + if (CountLoopCanBinded(all_loops[0].As()) > 0) { + applicable_schedule_blocks_.emplace_back(block_realize); + } + } + num_applicable_ = applicable_schedule_blocks_.size(); + VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; + return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; +} + +void AutoBind::Apply(int index) { + CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; + auto applied_block = applicable_schedule_blocks_.at(index); + auto all_loops = ir_schedule_->GetLoops(applied_block); + BindGPUIndex(ir_schedule_, + applied_block.As()->schedule_block.As()->name, + CountLoopCanBinded(all_loops[0].As()), + kMaxBlocks, + target_->max_num_threads()); + return; +} + +RuleApplyType AutoBind::AnalyseApplyType(SearchState state, const std::string& block_name) const { + Expr block_expr = state->ir_schedule.GetBlock(block_name); + auto all_loops = state->ir_schedule.GetLoops(block_expr); + return CountLoopCanBinded(all_loops[0].As()) > 0 ? 
RuleApplyType::kApplyAndPruneOtherRules + : RuleApplyType::kCannotApply; +} + +std::vector AutoBind::ApplyOnBlock(SearchState state, const std::string& block_name) { + SearchState new_state = state.Copy(); + auto all_loops = state->ir_schedule.GetLoops(block_name); + BindGPUIndex(&new_state->ir_schedule, + block_name, + CountLoopCanBinded(all_loops[0].As()), + kMaxBlocks, + target_->max_num_threads()); + return {new_state}; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h new file mode 100644 index 0000000000000..b93f633b230e3 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +// Auto bind GPU index(BlockIdx, ThreadIdx) to the loops around the block +class AutoBind : public AutoGenRule { + public: + AutoBind(const common::Target& target) : AutoGenRule(target) {} + ~AutoBind() = default; + + RuleApplyType Init(ir::IRSchedule* init_schedule) override; + + void Apply(int index) override; + + std::string GetRuleName() const override { return "AutoBind"; } + + RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; + + std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + + private: + std::vector applicable_schedule_blocks_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc new file mode 100644 index 0000000000000..9ffbe0a3f4a3a --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h" + +#include +#include + +#include +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" +#include "cinn/ir/ir_printer.h" +#include "tests/program_builder.h" + +namespace cinn { +namespace auto_schedule { + +static constexpr uint32_t kMaxBlocks = 256; +static constexpr uint32_t kMaxThreadsPerBlock = 1024; + +class TestAutoBind : public TestAutoGenRuleBase { + public: + std::vector default_input_names = {"X", "Y"}; + std::vector default_output_names = {"temp_matmul_out"}; + + void TestApplyOnElementWiseAdd(const std::vector& shape, const std::string& block_name) { + Initialize(common::DefaultNVGPUTarget()); + auto test_program = tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}}); + // construct input parameter + ir::IRSchedule ir_schedule = MakeIRSchedule(test_program); + SearchState state(ir_schedule, 0, {}); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // apply + AutoBind auto_bind(target_); + ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name), RuleApplyType::kApplyAndPruneOtherRules); + auto result = auto_bind.ApplyOnBlock(state, block_name)[0]; + std::vector exprs = result->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + VLOG(6) << "AutoBind applied Expr: " << exprs[0]; + + // check bind result + auto all_loops = result->ir_schedule.GetLoops(block_name); + int total_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + if (total_num <= kMaxThreadsPerBlock) { + ASSERT_EQ(all_loops.size(), 1); + EXPECT_EQ(all_loops[0].As()->extent.as_int32(), total_num); + EXPECT_TRUE(all_loops[0].As()->is_gpu_thread_binded()); + } else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) { + ASSERT_EQ(all_loops.size(), 2); + EXPECT_EQ(all_loops[0].As()->extent.as_int32(), + static_cast(std::ceil(double(total_num) / kMaxThreadsPerBlock))); + EXPECT_TRUE(all_loops[0].As()->is_gpu_block_binded()); + EXPECT_EQ(all_loops[1].As()->extent.as_int32(), kMaxThreadsPerBlock); + EXPECT_TRUE(all_loops[1].As()->is_gpu_thread_binded()); + } else { + ASSERT_EQ(all_loops.size(), 3); + EXPECT_EQ(all_loops[0].As()->extent.as_int32(), kMaxBlocks); + EXPECT_TRUE(all_loops[0].As()->is_gpu_block_binded()); + EXPECT_EQ(all_loops[1].As()->extent.as_int32(), kMaxThreadsPerBlock); + EXPECT_TRUE(all_loops[1].As()->is_gpu_thread_binded()); + EXPECT_EQ(all_loops[2].As()->extent.as_int32(), + static_cast(std::ceil(double(total_num) / (kMaxBlocks * kMaxThreadsPerBlock)))); + EXPECT_FALSE(all_loops[2].As()->is_binded()); + } + + // build and run + auto ir_module = BuildIRModule(result->ir_schedule); + auto source_code = GenSourceCode(ir_module); + VLOG(6) << "Optimized source code:\n" << source_code; + auto manual_ir_module = BuildIRModule(MakeIRSchedule(test_program, /* apply_manual_schedule*/ true)); + VLOG(6) << "Manual-schedule compiled source code:\n" << GenSourceCode(manual_ir_module); + CheckResult(GenExecutableKernel(ir_module), + GenExecutableKernel(manual_ir_module), + default_input_names, + {block_name}, + {shape, shape}, + {shape}, + target_); + } +}; + +TEST_F(TestAutoBind, AnalyseApplyType) { + Initialize(common::DefaultNVGPUTarget()); + ir::IRSchedule ir_schedule = MakeIRSchedule(tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}})); + SearchState state(ir_schedule, 0, {}); + AutoBind auto_bind(target_); + const std::string& 
applied_block_name = default_output_names.back(); + // outer two loops of initial Expr are spatial loops, so it can be applied + EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kApplyAndPruneOtherRules); + state->ir_schedule.Fuse(applied_block_name, {0, 1}); + state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0], "threadIdx.x"); + // after fuse and bind, there is no loops to be binded. + EXPECT_EQ(auto_bind.AnalyseApplyType(state, applied_block_name), RuleApplyType::kCannotApply); +} + +TEST_F(TestAutoBind, ApplyOnBlock) { + TestApplyOnElementWiseAdd({64, 128}, "var_1"); + TestApplyOnElementWiseAdd({57, 133, 125}, "var_1"); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc new file mode 100644 index 0000000000000..fb6eaa797b4c1 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" + +#include + +#include + +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {} + +int AutoGenRule::NumberApplicable() const { + CHECK_GE(num_applicable_, 0) << "Call " << GetRuleName() << "::NumberApplicable() without initialization."; + return num_applicable_; +} + +void AutoGenRule::ApplyRandomly() { + CHECK_GT(num_applicable_, 0) << "Call " << GetRuleName() << "::ApplyRandomly() with NumberApplicable() == 0"; + int index = rand() % num_applicable_; + return Apply(index); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h new file mode 100644 index 0000000000000..2a4ed201ad709 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h @@ -0,0 +1,84 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
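+// Typical driver flow over this interface (illustrative; `rule` stands for
+// any concrete subclass such as AutoBind or AutoInline):
+//   if (rule->Init(&ir_schedule) != RuleApplyType::kCannotApply) {
+//     rule->ApplyRandomly();  // or rule->Apply(i), 0 <= i < rule->NumberApplicable()
+//   }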
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/common/target.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * Enum class representing how this rule can be applied to a ModuleExpr.
+ */
+enum class RuleApplyType : int {
+  // This rule cannot be applied to the ModuleExpr.
+  kCannotApply = 0,
+  // This rule can be applied to the ModuleExpr,
+  // and the original ModuleExpr will be retained for branching with other rules.
+  kApply = 1,
+  // This rule can be applied, but the original ModuleExpr will be deleted,
+  // so branches with other rules applied on the original ModuleExpr will be pruned.
+  kApplyAndPruneOtherRules = 2,
+};
+
+/**
+ * Base class for rules that auto-generate schedules (similar to Ansor's sketch generation).
+ */
+class AutoGenRule {
+ public:
+  AutoGenRule(const common::Target& target);
+  ~AutoGenRule() = default;
+
+  // Initializes the AutoGenRule; it must be called before further actions.
+  // Returns kCannotApply if the rule cannot be applied to the mod_expr.
+  virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;
+
+  // A CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so an
+  // auto-gen rule may be applicable to several schedule blocks. This method
+  // returns the number of ScheduleBlocks that this rule can be applied to.
+  virtual int NumberApplicable() const;
+
+  // Applies the rule on the ir::ModuleExpr for a randomly chosen schedule block.
+  virtual void ApplyRandomly();
+
+  // Applies the rule on the ir::ModuleExpr for the schedule block specified by an index
+  // between 0 (inclusive) and NumberApplicable() (exclusive).
+  virtual void Apply(int index) = 0;
+
+  // Returns the name of the rule, used for debug.
+  virtual std::string GetRuleName() const = 0;
+
+  // Analyzes whether the rule can be applied to the block determined by a specific SearchState and block name.
+  virtual RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const = 0;
+
+  // Applies the rule to the block determined by a specific SearchState and block name.
+  virtual std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) = 0;
+
+ protected:
+  // number of ScheduleBlocks that this rule can be applied to
+  int num_applicable_ = -1;
+  // Target, not owned.
+  const common::Target* target_;
+  // IRSchedule, not owned.
+  ir::IRSchedule* ir_schedule_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
new file mode 100644
index 0000000000000..5b53ee148173c
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
@@ -0,0 +1,214 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h" + +#include +#include +#include +#include +#include +#include + +#include "cinn/auto_schedule/analysis/analyze_ir.h" +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/common/target.h" +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace auto_schedule { + +AutoInline::AutoInline(const common::Target& target, const std::unordered_set& no_inline_output_names) + : AutoGenRule(target), no_inline_output_names_(no_inline_output_names) {} + +bool AutoInline::CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const { + const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As(); + const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As(); + ir::Expr compute_body = sche_block->body; + ir::Expr root = ir_sch->GetRootBlock(sche_block_realize_expr); + + // Check the schedule block to be inlined is not a reduce tensor. + std::set find_store = + ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { return x->As(); }); + if (find_store.size() != 1UL) { + return false; + } + + ir::Expr tensor_expr = (*find_store.begin()).As()->tensor; + ir::Tensor tensor = tensor_expr.as_tensor_ref(); + if (tensor->is_reduce_tensor()) { + return false; + } + + // LoweredFunc output can be tensor name or tensor buffer name + if (no_inline_output_names_.find(tensor->name) != no_inline_output_names_.end() || + no_inline_output_names_.find(tensor->buffer->name) != no_inline_output_names_.end()) { + return false; + } + + // write_buffers.size() = 1 and read_buffers is empty, means const + // we can inline to consumer + if (sche_block->read_buffers.empty()) { + return true; + } + + // Check this schedule block is the only writer of the tensor. + find_store = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor->name; + }); + if (find_store.size() != 1UL) { + return false; + } + // Check there is no overlap between the buffers the schedule block reads and writes. 
+  std::set<ir::Expr> find_load = ir::CollectIRNodesWithoutTensor(
+      compute_body, [&](const Expr* x) { return x->As<ir::Load>() && x->As<ir::Load>()->tensor == tensor_expr; });
+  if (!find_load.empty()) {
+    return false;
+  }
+
+  ir::Expr store = *(find_store.begin());
+
+  ir::ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store);
+  if (!inliner.BodyPatternAllowInline()) {
+    return false;
+  }
+
+  ir::LeafBlockRemovalPlan remove_plan(sche_block_realize_expr, &inliner.src_stmt, &inliner.tgt_stmt);
+  remove_plan(&root);
+  if (!inliner.src_stmt.defined() || !inliner.tgt_stmt.defined()) {
+    return false;
+  }
+
+  VLOG(6) << "Found store Expr " << store << ", which CanInlineIntoConsumer";
+  return true;
+}
+
+AutoInlineType AutoInline::AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const {
+  const ir::ScheduleBlockRealize* sche_block_realize = sche_block_realize_expr.As<ir::ScheduleBlockRealize>();
+  const ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
+
+  // Inline only if the block has exactly 1 write buffer
+  if (sche_block->write_buffers.size() != 1) {
+    return AutoInlineType::kCannotInline;
+  }
+
+  std::unordered_set<ir::IrNodeTy> no_inline_node_types = {ir::IrNodeTy::IfThenElse};
+  if (ContainsNodeType(sche_block->body, no_inline_node_types)) {
+    return AutoInlineType::kCannotInline;
+  }
+
+  // InlineIntoConsumer in situations other than the above
+  if (CanInlineIntoConsumer(sche_block_realize_expr, ir_sch)) {
+    return AutoInlineType::kInlineIntoConsumer;
+  }
+
+  // TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
+  // so we only do kInlineIntoConsumer here. Add CanInlineIntoProducer
+  // once ReverseComputeInline is ready.
+  return AutoInlineType::kCannotInline;
+}
+
+RuleApplyType AutoInline::Init(ir::IRSchedule* ir_schedule) {
+  ir_schedule_ = ir_schedule;
+  all_block_realizes_ = ir_schedule_->GetAllBlocks();
+  apply_indices_and_type_.clear();
+  num_applicable_ = 0;
+
+  for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
+    ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
+    AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
+    AutoInlineType type = AnalyzeInlineType(all_block_realizes_[i], ir_schedule_);
+    if (type != AutoInlineType::kCannotInline) {
+      ++num_applicable_;
+      apply_indices_and_type_.push_back({i, type});
+    }
+  }
+
+  return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
+}
+
+void AutoInline::Apply(int index) {
+  CHECK(ir_schedule_ != nullptr) << "Run AutoInline::Apply without Init";
+  CHECK(num_applicable_ > 0 && apply_indices_and_type_.size() == num_applicable_)
+      << "AutoInline::Apply pre-condition isn't met";
+  CHECK(index >= 0 && num_applicable_ > index)
+      << "Invalid index for AutoInline::Apply, the index must satisfy 0 <= index < NumberApplicable(), "
+      << "currently index = " << index << ", NumberApplicable() = " << num_applicable_;
+
+  int apply_index = apply_indices_and_type_[index].first;
+  Apply(ir_schedule_, all_block_realizes_[apply_index]);
+  return;
+}
+
+std::string AutoInline::GetRuleName() const { return "AutoInline"; }
+
+RuleApplyType AutoInline::AnalyseApplyType(SearchState state, const std::string& block_name) const {
+  Expr block_expr = state->ir_schedule.GetBlock(block_name);
+  auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
+  CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
+
+  AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
+  AutoInlineType type = AnalyzeInlineType(block_expr, &state->ir_schedule);
+
+  return type == AutoInlineType::kCannotInline ? RuleApplyType::kCannotApply : RuleApplyType::kApplyAndPruneOtherRules;
+}
+
+std::vector<SearchState> AutoInline::ApplyOnBlock(SearchState state, const std::string& block_name) {
+  SearchState new_state = state.Copy();
+  Expr block_expr = new_state->ir_schedule.GetBlock(block_name);
+  Apply(&new_state->ir_schedule, block_expr);
+
+  return {new_state};
+}
+
+void AutoInline::Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
+  auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
+  CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
+
+  AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
+  AutoInlineType type = AnalyzeInlineType(block_expr, ir_schedule);
+
+  if (type == AutoInlineType::kInlineIntoConsumer) {
+    VLOG(6) << "Apply ComputeInline on " << block_expr;
+    ir_schedule->ComputeInline(block_expr);
+    VLOG(6) << "After ComputeInline: " << block_expr;
+
+  } else if (type == AutoInlineType::kInlineIntoProducer) {
+    // TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
+    // so this branch is unreachable for the moment. Add CanInlineIntoProducer
+    // once ReverseComputeInline is ready.
+
+    // ir_schedule->ReverseComputeInline(all_block_realizes_[apply_index]);
+  }
+
+  // Make sure re-applying AutoInline does not error out.
+  // AutoInline changes the read and write buffers of schedule blocks,
+  // so we need to re-analyze them.
+  all_block_realizes_ = ir_schedule->GetAllBlocks();
+  for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
+    ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
+    ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
+    sche_block->read_buffers = {};
+    sche_block->write_buffers = {};
+    AnalyzeScheduleBlockReadWriteBuffer(sche_block);
+  }
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
new file mode 100644
index 0000000000000..982092e717c33
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/common/target.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * The result types of the AutoInline analysis.
+ */
+enum class AutoInlineType : int {
+  // The block cannot be inlined
+  kCannotInline = 0,
+  // Inline this block into the consumer
+  kInlineIntoConsumer,
+  // Inline this block into the producer
+  kInlineIntoProducer,
+};
+
+class AutoInline : public AutoGenRule {
+ public:
+  AutoInline(const common::Target& target, const std::unordered_set<std::string>& no_inline_output_names);
+  ~AutoInline() = default;
+
+  RuleApplyType Init(ir::IRSchedule* ir_schedule) override;
+
+  void Apply(int index) override;
+
+  std::string GetRuleName() const override;
+
+  AutoInlineType AnalyzeInlineType(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const;
+
+  bool CanInlineIntoConsumer(const Expr& sche_block_realize_expr, ir::IRSchedule* ir_sch) const;
+
+  RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
+
+  std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
+
+ private:
+  void Apply(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
+
+ private:
+  std::vector<ir::Expr> all_block_realizes_;
+  std::vector<std::pair<int, AutoInlineType>> apply_indices_and_type_;
+  std::unordered_set<std::string> no_inline_output_names_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
new file mode 100644
index 0000000000000..a8d8ee9f9d0c0
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
@@ -0,0 +1,493 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h" + +#include +#include + +#include +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" +#include "cinn/cinn.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/ir/function_base.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/poly/stage.h" +#include "cinn/runtime/flags.h" +#include "cinn/utils/string.h" +#include "tests/concrete_program_builder.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace auto_schedule { + +using ::cinn::hlir::framework::Graph; +using ::cinn::hlir::framework::OpLowerer; + +TEST(AutoInline, SingleLoopInline) { + srand(0); + Context::Global().ResetNameId(); + Target target = common::DefaultHostTarget(); + + Expr M(32); + + Placeholder A("A", {M}); + ir::Tensor B = Compute( + {M}, [&](Var i) { return A(i) * ir::Expr(2.f); }, "B"); + ir::Tensor C = Compute( + {M}, [&](Var i) { return B(i) + ir::Expr(1.f); }, "C"); + + poly::StageMap stages = CreateStages({A, B, C}); + std::vector funcs = + lang::LowerVec("TestAutoInline_SingleLoopInline", stages, {A, C}, {}, {}, nullptr, target, true); + + VLOG(6) << "Expr after lowering:"; + VLOG(6) << funcs[0]->body; + + /* + * We have to use ComputeAt to put two Tensor loops together to create IR + * test case for AutoInline. + */ + ir::IRSchedule ir_sch(ir::ModuleExpr(std::vector{funcs[0]->body})); + SearchState state(ir_sch, 0, {}); + ir::Expr block_b = ir_sch.GetBlock("B"); + std::vector loops = ir_sch.GetLoops("C"); + ir_sch.ComputeAt(block_b, loops[0]); + + ir::ModuleExpr mod_expr_before_inline = ir_sch.GetModule(); + VLOG(6) << "Expr after ComputeAt:"; + VLOG(6) << mod_expr_before_inline.GetExprs()[0]; + + AutoInline auto_inline(target, {"C"}); + EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.NumberApplicable(), 1); + auto_inline.ApplyRandomly(); + std::vector exprs = ir_sch.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + + // ApplyOnBlock + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "B"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "B"); + + auto test_func = [](ir::IRSchedule* ir_sch) { + ir::ModuleExpr mod_expr_after_inline = ir_sch->GetModule(); + std::vector exprs = mod_expr_after_inline.GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + + std::stringstream ss; + ss << exprs[0]; + + std::string expr_str = ss.str(); + VLOG(6) << "After AutoInline:"; + VLOG(6) << expr_str; + + std::string target_str = R"ROC({ + ScheduleBlock(root) + { + { + serial for (i, 0, 32) + { + ScheduleBlock(C) + { + i0 = axis.bind(i) + read_buffers(_A[i0(0:32)]) + write_buffers(_C[i0(0:32)]) + C[i0] = ((A[i0] * 2.00000000f) + 1.00000000f) + } + } + } + } +})ROC"; + EXPECT_EQ(expr_str, target_str); + }; + + test_func(&ir_sch); + test_func(&new_states[0]->ir_schedule); + + // Cannot inline above expr again + EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "C"), RuleApplyType::kCannotApply); +} + +TEST(AutoInline, AddReluInline) { + srand(0); + Context::Global().ResetNameId(); + Target target = 
common::DefaultHostTarget(); + + frontend::NetBuilder builder("test"); + + auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A"); + auto b = builder.CreateInput(Float(32), {64}, "B"); + auto c = builder.Add(a, b, 1); + auto d = builder.Relu(c); + + frontend::Program program = builder.Build(); + + FLAGS_cinn_ir_schedule = true; + auto graph = std::make_shared(program, target); + hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); + + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs>("infershape"); + auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); + + EXPECT_EQ(graph->fusion_groups.size(), 1UL); + std::vector funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); + + VLOG(6) << "Expr before auto inline: " << funcs[0]->body; + + ir::ModuleExpr mod_expr_before_inline(std::vector({funcs[0]->body})); + ir::IRSchedule ir_sch(mod_expr_before_inline); + SearchState state(ir_sch, 0, {}); + + AutoInline auto_inline(target, {"var_2"}); + EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.NumberApplicable(), 2); + + auto_inline.Apply(1); + ir::ModuleExpr mod_expr_after_inline = ir_sch.GetModule(); + std::vector exprs = mod_expr_after_inline.GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + + std::stringstream ss; + ss << exprs[0]; + + std::string expr_str = ss.str(); + VLOG(6) << "After AutoInline:"; + VLOG(6) << expr_str; + + // Auto Inline again + EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(auto_inline.NumberApplicable(), 1); + auto_inline.Apply(0); + + // ApplyOnBlock + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_1"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); + // Auto Inline again + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_3"), RuleApplyType::kApplyAndPruneOtherRules); + new_states = auto_inline.ApplyOnBlock(new_states[0], "var_3"); + + auto test_func = [](ir::IRSchedule* ir_sch) { + ir::ModuleExpr final_mod_expr = ir_sch->GetModule(); + auto exprs = final_mod_expr.GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + + std::stringstream ss; + ss << exprs[0]; + + std::string expr_str = ss.str(); + VLOG(6) << "Final AutoInline:"; + VLOG(6) << expr_str; + + std::string target_str = R"ROC({ + ScheduleBlock(root) + { + { + serial for (i, 0, 1) + { + serial for (j, 0, 64) + { + serial for (k, 0, 112) + { + serial for (a, 0, 112) + { + ScheduleBlock(var_2) + { + i0, i1, i2, i3 = axis.bind(0, j, k, a) + read_buffers(_A[i0(0:1), i1(0:64), i2(0:112), i3(0:112)], _B[i1(0:64)]) + write_buffers(_var_2[i0(0:1), i1(0:64), i2(0:112), i3(0:112)]) + var_2[i0, i1, i2, i3] = cinn_max((A[i0, i1, i2, i3] + B[i1]), 0.00000000f) + } + } + } + } + } + } + } +})ROC"; + EXPECT_EQ(expr_str, target_str); + }; + + test_func(&ir_sch); + test_func(&new_states[0]->ir_schedule); + + // Cannot inline above expr again + EXPECT_EQ(auto_inline.Init(&ir_sch), RuleApplyType::kCannotApply); + EXPECT_EQ(auto_inline.AnalyseApplyType(new_states[0], "var_2"), RuleApplyType::kCannotApply); +} + +#ifdef CINN_WITH_CUDA +class TestAutoInline : public TestAutoGenRuleBase {}; + +/* The single chain graph composed of multiple blocks can be inlined into one. + * + * Before AutoInline: The output of the previous block is the input of another block. 
+ * Loop1: + * x1 = Add() + * Loop2: + * x2 = Multiply(x1) + * Loop3: + * x3 = Add(x2) + * Loop4: + * x4 = Relu(x3) + * + * After AutoInline: All loops are inlined into a loop. + * Loop: + * Add(Multiply(Add(Relu()))) + */ +TEST_F(TestAutoInline, SingleChain) { + Target target = common::DefaultNVGPUTarget(); + Initialize(target); + std::vector input_names = {"bias", "conv_output", "bn_scale", "bn_offset"}; + std::vector output_names = {"var_6", "var_5", "var_1", "var", "var_0", "var_4", "var_3"}; + std::vector conv_output_shape = {1, 512, 56, 56}; + int32_t channel = conv_output_shape[1]; + std::vector inputs_varinfo({{"conv_output", conv_output_shape}, + {"bias", {channel, 1, 1}}, + {"bn_scale", {channel, 1, 1}}, + {"bn_offset", {channel, 1, 1}}}); + + // Construct the computation graph and convert it to ir::Expr + Context::Global().ResetNameId(); + ir::IRSchedule ir_schedule = MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo)); + SearchState state(ir_schedule, 0, {}); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // Apply AutoInline for every block that can be inline + AutoInline auto_inline(target_, {output_names.front()}); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_3"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "var_3"); + std::vector inline_block_names({"var_4", "var_5", "var_6", "var", "var_0", "var_1"}); + for (const auto& inline_block_name : inline_block_names) { + new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); + } + std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; + + // build ir::Module and debug source code + auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); + auto build_module_manually = + BuildIRModule(MakeIRSchedule(tests::BiasBnReLUBuilder().Build(inputs_varinfo), -1, true)); + auto source_code_auto = GenSourceCode(build_module_auto); + VLOG(6) << " auto-schedule source code:\n" << source_code_auto; + auto source_code_manually = GenSourceCode(build_module_manually); + VLOG(6) << " manually-schedule source code:\n" << source_code_manually; + + CheckResult(GenExecutableKernel(build_module_auto), + GenExecutableKernel(build_module_manually), + input_names, + output_names, + {{conv_output_shape[1], 1, 1}, conv_output_shape, conv_output_shape, conv_output_shape}, + {conv_output_shape, {1}, {1}, {1}, {1}, {1}, {1}}, + target); +} + +/* An op can be inlined into multiple consumers at the same time. + * + * Before AutoInline: The output of Exp is used by Add and Multiply. + * Loop1: + * x = Exp() + * Loop2: + * y = Add(x) + * Loop3: + * z = Multiply(x) + * + * After AutoInline: Exp is inlined into Add and Multiply. 
+ * Loop: + * y = Add(Exp()) + * z = Multiply(Exp()) + */ +TEST_F(TestAutoInline, InlineToMultiConsumers) { + Target target = common::DefaultNVGPUTarget(); + Initialize(target); + std::vector input_names = {"x"}; + std::vector output_names = {"var_2", "var_1", "var_0"}; + std::vector input_shape{256, 256}; + std::vector inputs_varinfo({{"x", input_shape}}); + + // Construct the computation graph and convert it to ir::Expr + Context::Global().ResetNameId(); + ir::IRSchedule ir_schedule = MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo)); + SearchState state(ir_schedule, 0, {}); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // Apply AutoInline for every block that can be inline + AutoInline auto_inline(target_, {output_names.front()}); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "var_0"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "var_1"); + new_states = auto_inline.ApplyOnBlock(state, "var_0"); + std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; + + // build ir::Module and debug source code + auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); + auto build_module_manually = + BuildIRModule(MakeIRSchedule(tests::ExpTwoConsumersOpBuilder().Build(inputs_varinfo), -1, true)); + auto source_code_auto = GenSourceCode(build_module_auto); + VLOG(6) << " auto-schedule source code:\n" << source_code_auto; + auto source_code_manually = GenSourceCode(build_module_manually); + VLOG(6) << " manually-schedule source code:\n" << source_code_manually; + + CheckResult(GenExecutableKernel(build_module_auto), + GenExecutableKernel(build_module_manually), + input_names, + output_names, + {input_shape}, + {input_shape, {1}, {1}}, + target); +} + +/* Operators of type elementwise or injective can all be inlined. 
+ * + * Before AutoInline: A graph of Gather, Add and Subtract + * Loop1: + * x1 = Gather() + * Loop2: + * x2 = Add(x1) + * Loop3: + * y1 = Gather() + * Loop4: + * z1 = Subtract(y1, x1) + * + * After AutoInline: All loops are inlined to one + * z1 = Subtract(Gather(), Add(Gather())) + */ +TEST_F(TestAutoInline, OnlySpatialOp) { + Target target = common::DefaultNVGPUTarget(); + Initialize(target); + std::vector input_names = {"x", "y"}; + std::vector output_names = { + "var_6", "var_4", "constant_idx_last", "constant_idx_first", "var_2", "var_5"}; + std::vector input_shape{256, 256}; + std::vector inputs_varinfo({{"x", input_shape}, {"y", input_shape}}); + + // Construct the computation graph and convert it to ir::Expr + Context::Global().ResetNameId(); + ir::IRSchedule ir_schedule = MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo)); + SearchState state(ir_schedule, 0, {}); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // Apply AutoInline for every block that can be inline + AutoInline auto_inline(target_, {output_names.front()}); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "constant_idx_first"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "constant_idx_first"); + std::vector inline_block_names({"constant_idx_last", "var_2", "var_5", "var_4"}); + for (const auto& inline_block_name : inline_block_names) { + new_states = auto_inline.ApplyOnBlock(new_states[0], inline_block_name); + } + std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; + + // build ir::Module and debug source code + auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); + auto build_module_manually = + BuildIRModule(MakeIRSchedule(tests::GatherAddSubBuilder().Build(inputs_varinfo), -1, true)); + auto source_code_auto = GenSourceCode(build_module_auto); + VLOG(6) << " auto-schedule source code:\n" << source_code_auto; + auto source_code_manually = GenSourceCode(build_module_manually); + VLOG(6) << " manually-schedule source code:\n" << source_code_manually; + + CheckResult(GenExecutableKernel(build_module_auto), + GenExecutableKernel(build_module_manually), + input_names, + output_names, + {input_shape, input_shape}, + {input_shape, {1}, {1}, {1}, {1}, {1}}, + target); +} + +/* An op that does not read data can be directly inlined. + * + * Before AutoInline: fill_constant op is in a separate loop. 
+ * Loop1: + * x = fill_constant() + * Loop2: + * y = Add(x) + * + * After AutoInline: fill_constant op is inlined into other loop + * Loop: + * y = Add(fill_constant()) + */ +TEST_F(TestAutoInline, NoReadBufferOp) { + Target target = common::DefaultNVGPUTarget(); + Initialize(target); + std::vector input_names = {"x"}; + std::vector output_names = {"var_0", "fill_constant"}; + std::vector input_shape{256, 256}; + std::vector inputs_varinfo({{"x", input_shape}}); + + // Construct the computation graph and convert it to ir::Expr + ir::IRSchedule ir_schedule = MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo)); + SearchState state(ir_schedule, 0, {}); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // Apply AutoInline for every block that can be inline + AutoInline auto_inline(target_, {output_names.front()}); + EXPECT_EQ(auto_inline.AnalyseApplyType(state, "fill_constant"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = auto_inline.ApplyOnBlock(state, "fill_constant"); + std::vector exprs = new_states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + VLOG(6) << "Expr after AutoInline applied on block: " << exprs[0]; + + // build ir::Module and debug source code + auto build_module_auto = BuildIRModule(new_states[0]->ir_schedule); + auto build_module_manually = + BuildIRModule(MakeIRSchedule(tests::FillConstantAddBuilder().Build(inputs_varinfo), -1, true)); + auto source_code_auto = GenSourceCode(build_module_auto); + VLOG(6) << " auto-schedule source code:\n" << source_code_auto; + auto source_code_manually = GenSourceCode(build_module_manually); + VLOG(6) << " manually-schedule source code:\n" << source_code_manually; + + CheckResult(GenExecutableKernel(build_module_auto), + GenExecutableKernel(build_module_manually), + input_names, + output_names, + {input_shape}, + {input_shape, {1}}, + target); +} + +/* An op can be inlined into multiple producers at the same time. + */ +// TEST_F(TestAutoInline, InlineToMultiProducers) { +// TODO(6clc): Complete the unit test, once ReverseComputeInline is ready. +// } +#endif +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc new file mode 100644 index 0000000000000..a4bc75ef1af83 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc @@ -0,0 +1,120 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h" + +#include + +#include + +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace auto_schedule { + +static std::vector auto_unroll_options = {0, 8, 32, 128}; + +bool AutoUnroll::MeetCondition(const ir::ScheduleBlock* schedule_block) const { + // whether any block has reduce iter + auto has_reduce_iter = [](const Expr* x) { + auto* block_realize = x->As(); + if (block_realize) { + auto* schedule_block = block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock"; + for (auto&& var : schedule_block->iter_vars) { + if (var->is_reduce_axis) { + VLOG(6) << "find ScheduleBlockRealize:" << *x << " has reduce_axis:" << var; + return true; + } + } + } + return false; + }; + // whether has any for-loop with non-serial type + auto has_nonserial_loop = [](const Expr* x) { + if (x->As() && x->As()->for_type() != ir::ForType::Serial) { + VLOG(6) << "find non-serial loop:" << *x; + return true; + } + return false; + }; + + auto find_target_exprs = ir::CollectIRNodesWithoutTensor( + schedule_block->body, + [&has_reduce_iter, &has_nonserial_loop](const Expr* x) { return has_reduce_iter(x) || has_nonserial_loop(x); }); + + return !find_target_exprs.empty(); +} + +RuleApplyType AutoUnroll::Init(ir::IRSchedule* ir_schedule) { + ir_schedule_ = ir_schedule; + auto block_realizes = ir_schedule_->GetAllBlocks(); + + // A schedule block can perform `auto_unroll` rule should meet two conditions: + // (1) it is a root block + // (2) MeetCondition returns true with it + applicable_schedule_blocks_.clear(); + std::set deduplicate_results; + for (size_t i = 0; i < block_realizes.size(); ++i) { + // find root block + Expr root_block = ir_schedule_->GetRootBlock(block_realizes[i]); + auto* block_realize = root_block.As(); + CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; + auto* schedule_block = block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); + if (MeetCondition(schedule_block)) { + deduplicate_results.emplace(root_block); + } + } + applicable_schedule_blocks_ = {deduplicate_results.begin(), deduplicate_results.end()}; + num_applicable_ = applicable_schedule_blocks_.size(); + VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_; + + return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; +} + +void AutoUnroll::Apply(int index) { + CHECK_LT(index, applicable_schedule_blocks_.size()) << "invalid apply index:" << index; + auto applied_block = applicable_schedule_blocks_.at(index); + int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; + ir_schedule_->Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); + return; +} + +RuleApplyType AutoUnroll::AnalyseApplyType(SearchState state, const std::string& block_name) const { + Expr block_expr = state->ir_schedule.GetBlock(block_name); + Expr root_block = state->ir_schedule.GetRootBlock(block_expr); + auto* block_realize = root_block.As(); + CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << root_block; + auto* schedule_block = block_realize->schedule_block.As(); + CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock:" << Expr(block_realize); + + return MeetCondition(schedule_block) ? 
RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply; +} + +std::vector AutoUnroll::ApplyOnBlock(SearchState state, const std::string& block_name) { + SearchState new_state = state.Copy(); + Expr block_expr = new_state->ir_schedule.GetBlock(block_name); + Expr applied_block = new_state->ir_schedule.GetRootBlock(block_expr); + int max_step = auto_unroll_options[std::rand() % auto_unroll_options.size()]; + new_state->ir_schedule.Annotate(applied_block, ir::attr::auto_unroll_max_step, max_step); + + return {new_state}; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h new file mode 100644 index 0000000000000..f1b67d173cf3f --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h @@ -0,0 +1,54 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +// This rule can be applied in a ScheduleBlock has reduce axis or has loops with non-serial type. +// As a result, it will set a attribute with key named ir::attr::auto_unroll_max_step and value +// indicating max permitted unrolled step in the applied ScheduleBlock. Finally, UnrollLoop pass +// will do unroll based on actual situation. +class AutoUnroll : public AutoGenRule { + public: + AutoUnroll(const common::Target& target) : AutoGenRule(target) {} + ~AutoUnroll() = default; + + RuleApplyType Init(ir::IRSchedule* init_schedule) override; + + void Apply(int index) override; + + std::string GetRuleName() const override { return "AutoUnroll"; } + + RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override; + + std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override; + + private: + bool MeetCondition(const ir::ScheduleBlock* schedule_block) const; + + private: + std::vector applicable_schedule_blocks_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc new file mode 100644 index 0000000000000..99688a2da6738 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
+
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "cinn/cinn.h"
+#include "cinn/lang/lower.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+TEST(AutoUnroll, Init) {
+  using namespace ir;
+
+  Expr M(100);
+  Expr N(4);
+  Placeholder<float> A("A", {M, N});
+  Placeholder<float> B("B", {M, N});
+  Tensor C = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");
+
+#ifdef CINN_WITH_CUDA
+  Target target = common::DefaultNVGPUTarget();
+#else
+  Target target = common::DefaultHostTarget();
+#endif
+  auto stages = CreateStages({C});
+  auto funcs = cinn::lang::LowerVec("test_init", stages, {A, B, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = funcs[0]->body;
+  ir::IRSchedule init_schedule(ir::ModuleExpr({ast_expr}));
+  AutoUnroll test_rule(target);
+  // does not meet the specific condition
+  ASSERT_EQ(test_rule.Init(&init_schedule), RuleApplyType::kCannotApply);
+}
+
+TEST(AutoUnroll, UnrollableApply) {
+  using namespace ir;
+
+  Expr M(100);
+  Expr N(4);
+  Expr K(32);
+  Placeholder<float> A("A", {M, K});
+  Placeholder<float> B("B", {K, N});
+  Var k(K.as_int32(), "k0");
+  Tensor C = Compute(
+      {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
+
+#ifdef CINN_WITH_CUDA
+  Target target = common::DefaultNVGPUTarget();
+#else
+  Target target = common::DefaultHostTarget();
+#endif
+  auto stages = CreateStages({C});
+  auto funcs = cinn::lang::LowerVec("test_unrollable", stages, {A, B, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = funcs[0]->body;
+  auto* init_block_realize = ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
+  auto* init_schedule_block = init_block_realize->schedule_block.As<ir::ScheduleBlock>();
+  ASSERT_NE(init_schedule_block, nullptr);
+  ASSERT_TRUE(init_schedule_block->attrs.empty());
+  VLOG(6) << "Before auto-unroll:\n" << ast_expr;
+
+  AutoUnroll test_rule(target);
+  ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr}));
+  SearchState state(ir_schedule, 0, {});
+  ASSERT_EQ(test_rule.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules);
+  EXPECT_EQ(test_rule.NumberApplicable(), 1);
+  test_rule.ApplyRandomly();
+
+  // ApplyOnBlock
+  EXPECT_EQ(test_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules);
+  std::vector<SearchState> states = test_rule.ApplyOnBlock(state, "C");
+
+  auto test_func = [](IRSchedule* ir_sch) {
+    Expr applied_expr = ir_sch->GetModule().GetExprs().front();
+    auto* applied_block_realize = applied_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
+    auto* applied_schedule_block = applied_block_realize->schedule_block.As<ir::ScheduleBlock>();
+    ASSERT_FALSE(applied_schedule_block->attrs.empty());
+    EXPECT_EQ(applied_schedule_block->attrs.count(ir::attr::auto_unroll_max_step), 1);
+    const auto& attr_value = applied_schedule_block->attrs.at(ir::attr::auto_unroll_max_step);
+    const int* max_step = absl::get_if<int>(&attr_value);
+    EXPECT_NE(max_step, nullptr);
+    EXPECT_LE(*max_step, 128);
+    VLOG(6) << "After auto-unroll:max_step=" << *max_step << ", Ast:\n" << ir_sch->GetModule().GetExprs().front();
+  };
+
+  test_func(&ir_schedule);
+  test_func(&states[0]->ir_schedule);
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
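Taken together, the two tests above exercise both entry points shared by every AutoGenRule. As a reference, here is a minimal sketch of the two usage patterns, assuming an ir::IRSchedule named ir_sch and a SearchState named state built as in TEST(AutoUnroll, UnrollableApply), where "C" is the reduction output block of that test:

// Pattern 1: stateful interface driven by Init(), as the random search loop uses it.
AutoUnroll rule(target);
if (rule.Init(&ir_sch) != RuleApplyType::kCannotApply) {
  rule.ApplyRandomly();  // or rule.Apply(i) for any 0 <= i < rule.NumberApplicable()
}

// Pattern 2: stateless per-block interface; the input state is copied, never mutated.
if (rule.AnalyseApplyType(state, "C") != RuleApplyType::kCannotApply) {
  std::vector<SearchState> next_states = rule.ApplyOnBlock(state, "C");
}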
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc new file mode 100644 index 0000000000000..21ed0e94f9ddf --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h" +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "tests/program_builder.h" + +namespace cinn { +namespace auto_schedule { + +class TestMixRules : public TestAutoGenRuleBase { + public: + std::vector default_input_names = {"X", "Y"}; + std::vector default_output_names = {"temp_matmul_out"}; +}; + +TEST_F(TestMixRules, 2DMatmulOnMultiTilingRelated) { + frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}); + Initialize(common::DefaultNVGPUTarget()); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op); + std::vector func_bodys = ir_schedule.GetModule().GetExprs(); + ASSERT_EQ(func_bodys.size(), 1UL); + VLOG(6) << "Original Expr:\n" << func_bodys[0]; + + // Apply MultiLevelTiling + MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); + multi_level_tiling.Init(&ir_schedule); + ASSERT_EQ(multi_level_tiling.NumberApplicable(), 1); + multi_level_tiling.ApplyRandomly(); + VLOG(6) << "after MultiLevelTiling Expr:\n" << func_bodys[0]; + + // build ir::Module and debug source code + auto ir_module = BuildIRModule(ir_schedule); + auto source_code = GenSourceCode(ir_module); + VLOG(6) << "scheduled source code:\n" << source_code; + // execute and check precision + CheckResult(GenExecutableKernel(ir_module), + GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, /* apply_manual_schedule */ true))), + default_input_names, + default_output_names, + {{32, 32}, {32, 32}}, + {{32, 32}}, + target_); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc new file mode 100644 index 0000000000000..3dee778f8f886 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc @@ -0,0 +1,401 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
+
+#include <glog/logging.h>
+
+#include <algorithm>
+#include <cstdlib>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "cinn/auto_schedule/analysis/analyze_ir.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/common/target.h"
+#include "cinn/ir/buffer.h"
+#include "cinn/ir/collect_ir_nodes.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/optim/ir_copy.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+MultiLevelTiling::MultiLevelTiling(const common::Target& target, const Config& config)
+    : AutoGenRule(target), config_(config) {
+  for (int i = 0; i < config_.tile_struct.size(); ++i) {
+    if (config_.tile_struct[i] == 'S') {
+      s_indices_.push_back(i);
+    } else if (config_.tile_struct[i] == 'R') {
+      r_indices_.push_back(i);
+    } else {
+      CHECK(false) << "Illegal tiling structure string";
+    }
+  }
+}
+
+bool MultiLevelTiling::MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const {
+  return NeedsMultiLevelTiling(sche_block_realize);
+}
+
+RuleApplyType MultiLevelTiling::Init(ir::IRSchedule* ir_schedule) {
+  ir_schedule_ = ir_schedule;
+  all_block_realizes_ = ir_schedule_->GetAllBlocks();
+  applicable_indices_.clear();
+  num_applicable_ = 0;
+  for (size_t i = 0; i < all_block_realizes_.size(); ++i) {
+    ir::ScheduleBlockRealize* sche_block_realize = all_block_realizes_[i].As<ir::ScheduleBlockRealize>();
+    AnalyzeScheduleBlockReadWriteBuffer(sche_block_realize->schedule_block.As<ir::ScheduleBlock>());
+    if (MeetCondition(*sche_block_realize)) {
+      ++num_applicable_;
+      applicable_indices_.push_back(i);
+    }
+  }
+
+  return num_applicable_ > 0 ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
+}
+
+void MultiLevelTiling::Apply(int index) {
+  CHECK(ir_schedule_ != nullptr) << "Run MultiLevelTiling::Apply without Init";
+  CHECK(num_applicable_ > 0 && applicable_indices_.size() == num_applicable_)
+      << "MultiLevelTiling::Apply pre-condition isn't met";
+  CHECK(index >= 0 && num_applicable_ > index)
+      << "Invalid index for MultiLevelTiling::Apply, the index must satisfy 0 <= index < NumberApplicable(), "
+      << "currently index = " << index << ", NumberApplicable() = " << num_applicable_;
+
+  int apply_index = applicable_indices_[index];
+  std::string block_name = all_block_realizes_[apply_index]
+                               .As<ir::ScheduleBlockRealize>()
+                               ->schedule_block.As<ir::ScheduleBlock>()
+                               ->name;
+  Expr block_expr = all_block_realizes_[apply_index];
+  ApplyTiling(ir_schedule_, block_expr);
+  block_expr = ir_schedule_->GetBlock(block_name);
+  ApplyCacheRead(ir_schedule_, block_expr);
+  block_expr = ir_schedule_->GetBlock(block_name);
+  ApplyCacheWrite(ir_schedule_, block_expr);
+
+  VLOG(4) << "Returning the result of MultiLevelTiling";
+  return;
+}
+
+std::string MultiLevelTiling::GetRuleName() const { return "MultiLevelTiling"; }
+
+RuleApplyType MultiLevelTiling::AnalyseApplyType(SearchState state, const std::string& block_name) const {
+  Expr block_expr = state->ir_schedule.GetBlock(block_name);
+  auto* block_realize = block_expr.As<ir::ScheduleBlockRealize>();
+  CHECK(block_realize) << "stmt is not a ScheduleBlockRealize:" << block_expr;
+  AnalyzeScheduleBlockReadWriteBuffer(block_realize->schedule_block.As<ir::ScheduleBlock>());
+
+  return NeedsMultiLevelTiling(*block_realize) ? RuleApplyType::kApplyAndPruneOtherRules : RuleApplyType::kCannotApply;
+}
+
+std::vector<SearchState> MultiLevelTiling::ApplyOnBlock(SearchState state, const std::string& block_name) {
+  SearchState new_state = state.Copy();
+  ir::IRSchedule* ir_sch = &new_state->ir_schedule;
+  Expr block_expr = ir_sch->GetBlock(block_name);
+  ApplyTiling(ir_sch, block_expr);
+  block_expr = ir_sch->GetBlock(block_name);
+  ApplyCacheRead(ir_sch, block_expr);
+  block_expr = ir_sch->GetBlock(block_name);
+  ApplyCacheWrite(ir_sch, block_expr);
+
+  VLOG(4) << "Returning the result of MultiLevelTiling";
+  return {new_state};
+}
+
+void MultiLevelTiling::ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
+  ir::ScheduleBlockRealize* sche_block_realize = block_expr.As<ir::ScheduleBlockRealize>();
+  ir::ScheduleBlock* sche_block = sche_block_realize->schedule_block.As<ir::ScheduleBlock>();
+  tile_loops_.clear();
+  tile_loops_.resize(config_.tile_struct.size());
+  std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
+
+  VLOG(5) << "The number of loops to split in MultiLevelTiling is " << for_exprs.size();
+  for (int i = for_exprs.size() - 1; i >= 0; --i) {
+    ir::For* ir_for = for_exprs[i].As<ir::For>();
+    VLOG(6) << "Applying Split for MultiLevelTiling on: " << Expr(ir_for);
+    const std::vector<int>* idx = nullptr;
+    if (sche_block->iter_vars[i]->is_reduce_axis) {
+      idx = &r_indices_;
+    } else {
+      idx = &s_indices_;
+    }  // TODO: support more iterator variable types
+
+    int extent = ir_for->extent.as_int32();  // TODO: extents wider than int32 are not handled yet
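+
+    // (Illustrative note: SamplePerfectTile below draws num_split random factors
+    // whose product equals the loop extent, with the innermost factor capped at
+    // 64; e.g. a loop of extent 128 split three ways might come back as {4, 8, 4}.
+    // The sampling is random, so repeated applications can yield different tilings.)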
+    int num_split = idx->size();
+    if (num_split > 1) {
+      std::vector<Expr> tile_split_factor = ir_schedule->SamplePerfectTile(Expr(ir_for), num_split, 64);
+      std::vector<Expr> splited = ir_schedule->Split(Expr(ir_for), tile_split_factor);
+      VLOG(6) << "Finish Split for MultiLevelTiling on above loop";
+      for (int j = 0; j < num_split; ++j) {
+        tile_loops_[idx->at(j)].push_back(splited[j]);
+      }
+    } else {
+      tile_loops_[idx->at(0)].push_back(for_exprs[i]);
+    }
+  }
+  VLOG(5) << "Finish Split in MultiLevelTiling, before Reorder.";
+
+  // Have to GetLoops again because Split can change Block Expr(s)
+  for_exprs = ir_schedule->GetLoops(sche_block->name);
+  std::unordered_map<std::string, int> loop_var_name_to_idx;
+  for (int i = 0; i < for_exprs.size(); ++i) {
+    loop_var_name_to_idx[for_exprs[i].As<ir::For>()->loop_var->name] = i;
+  }
+  CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names after split";
+
+  std::vector<Expr> splited_loops;
+  for (auto& t : tile_loops_) {
+    std::reverse(t.begin(), t.end());
+    for (auto& tile_loop_expr : t) {
+      const ir::For* tile_loop = tile_loop_expr.As<ir::For>();
+      CHECK(tile_loop) << "tiles store non For Expr";
+      int idx = loop_var_name_to_idx[tile_loop->loop_var->name];
+      splited_loops.push_back(for_exprs[idx]);
+    }
+  }
+
+  Expr reordered_expr = ir_schedule->Reorder(splited_loops);
+  VLOG(5) << "Finish Reorder in MultiLevelTiling, now do Fuse and Binding on the main loop chain";
+
+  int num_binds = std::min(config_.bind_axis.size(), tile_loops_.size());
+  for (int i = 0; i < num_binds; ++i) {
+    loop_var_name_to_idx.clear();
+    for_exprs = ir_schedule->GetLoops(sche_block->name);
+    for (int j = 0; j < for_exprs.size(); ++j) {
+      loop_var_name_to_idx[for_exprs[j].As<ir::For>()->loop_var->name] = j;
+    }
+    CHECK(loop_var_name_to_idx.size() == for_exprs.size()) << "Loops contain duplicate loop var names before Fusion";
+
+    // Some loop extents may exceed the limited max factor (for example, exceed
+    // the limit number of CUDA threads), so here we check whether the fused
+    // loop extent, which is the product of the extents of the loops to be
+    // fused, is less than or equal to the max factor.
+    //
+    // If yes, we fuse those loops and bind the fused loop.
+    // If no, we bind the first loop whose extent is less than the factor.
+    int extent_prod = 1;
+    int first_idx_less_than_max_factor = -1;
+    for (int j = 0; j < tile_loops_[i].size(); ++j) {
+      const ir::For* tile_loop = tile_loops_[i][j].As<ir::For>();
+      CHECK(tile_loop) << "tiles store non For Expr";
+      int idx = loop_var_name_to_idx[tile_loop->loop_var->name];
+      tile_loops_[i][j] = for_exprs[idx];
+      int extent = tile_loop->extent.as_int32();  // TODO: extents wider than int32 are not handled yet
+      extent_prod *= extent;
+      if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
+        first_idx_less_than_max_factor = idx;
+      }
+    }
+
+    if (extent_prod <= max_factor_) {
+      Expr fused = ir_schedule->Fuse(tile_loops_[i]);
+      ir_schedule->Bind(fused, config_.bind_axis[i]);
+    } else if (first_idx_less_than_max_factor != -1) {
+      ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]);
+    }
+  }
+
+  VLOG(5) << "Do Fuse and Binding on the non-main loop chains";
+  Expr sche_block_top_loop = ir_schedule->GetLoops(sche_block->name)[0];
+
+  if (reordered_expr.As<ir::Block>()) {
+    for (Expr& top_loop : reordered_expr.As<ir::Block>()->stmts) {
+      if (top_loop != sche_block_top_loop) {
+        std::vector<Expr> scan_loop_blocks = ir_schedule->GetAllBlocks();
+        Expr other_loop_chain_schedule;
+        for (Expr& block : scan_loop_blocks) {
+          std::vector<Expr> loop_chain = ir_schedule->GetLoops(block);
+          if (loop_chain[0] == top_loop) {
+            other_loop_chain_schedule = block;
+            break;
+          }
+        }
+        if (!other_loop_chain_schedule.defined()) {
+          LOG(WARNING) << "Has non-main loop chain, but no corresponding ScheduleBlock in MultiLevelTiling";
+          continue;
+        }
+
+        std::string other_loop_schedule_name =
+            other_loop_chain_schedule.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
+        VLOG(6) << "Found other_loop_schedule_name = " << other_loop_schedule_name;
+        int fuse_index = 0;
+        for (int i = 0; i < num_binds; ++i) {
+          for_exprs = ir_schedule->GetLoops(other_loop_schedule_name);
+
+          // Some loop extents may exceed the limited max factor (for example,
+          // exceed the limit number of CUDA threads), so here we check whether
+          // the fused loop extent, which is the product of the extents of the
+          // loops to be fused, is less than or equal to the max factor.
+          //
+          // If yes, we fuse those loops and bind the fused loop.
+          // If no, we bind the first loop whose extent is less than the factor.
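+          // (Worked example, assuming max_factor_ is 1024, the usual CUDA limit on
+          // threads per block: tile extents {8, 16, 4} multiply to 512 <= 1024, so
+          // the three loops are fused and bound to one axis; extents {64, 64} multiply
+          // to 4096 > 1024, so nothing is fused and only the first loop whose extent
+          // fits, 64, is bound.)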
+          int extent_prod = 1;
+          int first_idx_less_than_max_factor = -1;
+          for (int j = 0; j < tile_loops_[i].size(); ++j) {
+            int extent = for_exprs[fuse_index + j].As<ir::For>()->extent.as_int32();
+            extent_prod *= extent;
+            if (first_idx_less_than_max_factor == -1 && extent <= max_factor_) {
+              first_idx_less_than_max_factor = fuse_index + j;
+            }
+          }
+          if (extent_prod <= max_factor_) {
+            std::vector<Expr> loops_to_fuse(for_exprs.begin() + fuse_index,
+                                            for_exprs.begin() + fuse_index + tile_loops_[i].size());
+            Expr fused = ir_schedule->Fuse(loops_to_fuse);
+            ir_schedule->Bind(fused, config_.bind_axis[i]);
+            fuse_index += 1;
+          } else if (first_idx_less_than_max_factor != -1) {
+            ir_schedule->Bind(for_exprs[first_idx_less_than_max_factor], config_.bind_axis[i]);
+            fuse_index += tile_loops_[i].size();
+          }
+        }
+      }
+    }
+  }
+}
+
+void MultiLevelTiling::ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
+  ir::ScheduleBlockRealize* sch_block_realize = block_expr.As<ir::ScheduleBlockRealize>();
+  ir::ScheduleBlock* sch_block = sch_block_realize->schedule_block.As<ir::ScheduleBlock>();
+  std::string block_name = sch_block->name;
+
+  // Analyze which buffers can be cached
+  std::vector<int> read_buffer_indexes;
+  for (int i = 0; i < sch_block->read_buffers.size(); ++i) {
+    bool is_read_write = false;
+    for (int j = 0; j < sch_block->write_buffers.size(); ++j) {
+      if (sch_block->read_buffers[i] == sch_block->write_buffers[j]) {
+        is_read_write = true;
+        break;
+      }
+    }
+    if (!is_read_write) {
+      read_buffer_indexes.push_back(i);
+    }
+  }
+
+  // Schedule
+  for (int read_buffer_index : read_buffer_indexes) {
+    for (int level : config_.read_cache_levels) {
+      // 1. Find the target loop
+      const auto loops = tile_loops_.at(level - 1);
+      if (loops.size() == 0) {
+        continue;
+      }
+
+      // 2. Do CacheRead and get the cache block
+      ir::Expr cache_block = ir_schedule->CacheRead(block_expr, read_buffer_index, config_.read_cache_memory_type);
+      std::string cache_block_name =
+          cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
+
+      std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name;
+
+      // 3. Place the cache_block under the target_for_loop.
+      // The original block expr is invalid after the CacheRead schedule,
+      // so we reacquire the block expr after the schedule according to the block name.
+      block_expr = ir_schedule->GetBlock(block_name);
+      std::vector<Expr> for_exprs = ir_schedule->GetLoops(block_expr);
+      for (const Expr& for_expr : for_exprs) {
+        if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) {
+          ir_schedule->ComputeAt(cache_block, for_expr, true);
+          break;
+        }
+      }
+
+      // 4. Threads under the same block cooperatively fetch data from global memory.
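+      // (Note: the ir::attr::cooperative_process annotation applied below is what
+      // marks the cache block for this cooperative fetch; later lowering is expected
+      // to distribute the copy across the threads already bound in this block. The
+      // constant 0 passed as the annotation value is a placeholder vector length,
+      // per the TODO below.)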
+      Expr new_cache_block = ir_schedule->GetBlock(cache_block_name);
+      auto cache_block_loops = ir_schedule->GetLoops(new_cache_block);
+      std::vector<std::string> compute_at_extra_var = utils::Split(
+          absl::get<std::string>(
+              new_cache_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->attrs.at(
+                  "compute_at_extra_var")),
+          ",");
+      std::vector<Expr> buffer_loops;
+      // int nthreads = 1;
+      for (const Expr& for_expr : cache_block_loops) {
+        if (std::find(compute_at_extra_var.begin(),
+                      compute_at_extra_var.end(),
+                      for_expr.As<ir::For>()->loop_var->name) != compute_at_extra_var.end()) {
+          buffer_loops.push_back(for_expr);
+        }
+      }
+      auto fused_buffer_loop = ir_schedule->Fuse(buffer_loops);
+      // TODO(BiynXu): Implement vectorized data fetching and pass in the vector length
+      ir_schedule->Annotate(ir_schedule->GetBlock(cache_block_name), ir::attr::cooperative_process, 0);
+    }
+  }
+}
+
+void MultiLevelTiling::ApplyCacheWrite(ir::IRSchedule* ir_schedule, ir::Expr& block_expr) {
+  ir::Expr cache_block = ir_schedule->CacheWrite(block_expr, 0, config_.write_cache_memory_type);
+
+  for (int level : config_.write_cache_levels) {
+    const auto loops = tile_loops_.at(level - 1);
+    if (loops.size() == 0) {
+      continue;
+    }
+    std::string target_for_loop_name = loops.back().As<ir::For>()->loop_var->name;
+    // Because CacheWrite changes the block name, we compute the derived name
+    // according to the logic of CacheWrite and look up the loop structure by
+    // that derived name.
+    const std::string original_block_name =
+        block_expr.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name;
+    const std::string derivative_block_name =
+        original_block_name + "_" + config_.write_cache_memory_type + "_temp_buffer";
+    std::vector<Expr> for_exprs = ir_schedule->GetLoops(derivative_block_name);
+    for (const Expr& for_expr : for_exprs) {
+      if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos) {
+        ir_schedule->ReverseComputeAt(ir_schedule->GetBlock(original_block_name), for_expr, true);
+      }
+    }
+
+    const std::string reduce_init_block_name = original_block_name + "__reduce_init";
+    for_exprs = ir_schedule->GetLoops(derivative_block_name);
+    for (const Expr& for_expr : for_exprs) {
+      if (for_expr.As<ir::For>()->loop_var->name.find(target_for_loop_name) != std::string::npos &&
+          ir_schedule->HasBlock(reduce_init_block_name)) {
+        ir_schedule->SimpleComputeAt(ir_schedule->GetBlock(reduce_init_block_name), for_expr);
+      }
+    }
+  }
+}
+
+const std::unordered_map<common::Target::Arch, MultiLevelTiling::Config> MultiLevelTiling::kConfigs{
+    {common::Target::Arch::NVGPU,
+     MultiLevelTiling::Config{
+         /*bind_axis*/ std::vector<std::string>{"blockIdx.x", "threadIdx.x"},
+         /*tile_struct*/ std::string("SSSRRSRS"),
+         /*read_cache_memory_type*/ std::string("shared"),
+         /*read_cache_levels*/ std::vector<int>{4},
+         /*write_cache_memory_type*/ std::string("local"),
+         /*write_cache_levels*/ std::vector<int>{3},
+     }},
+    {common::Target::Arch::X86,
+     MultiLevelTiling::Config{
+         /*bind_axis*/ std::vector<std::string>{},
+         /*tile_struct*/ std::string("SSRSRS"),
+         /*read_cache_memory_type*/ std::string("local"),
+         /*read_cache_levels*/ std::vector<int>{3},
+         /*write_cache_memory_type*/ std::string("local"),
+         /*write_cache_levels*/ std::vector<int>{2},
+     }}};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h
new file mode 100644
index 0000000000000..0756071657dbd
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h
@@ -0,0 +1,138 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <glog/logging.h>
+
+#include <cmath>
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/common/target.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+class MultiLevelTiling : public AutoGenRule {
+ public:
+  struct Config {
+    // Which thread axis each tiled loop is bound to
+    std::vector<std::string> bind_axis;
+    // Use the chars 'S' and 'R' to represent the tile structure:
+    // S means a space tiling level and R means a reduce tiling level.
+    //
+    // For example, if tile_struct = "SSRSRS" and we are doing matrix
+    // multiplication, where i, j are the spatial indices and k is the reduce
+    // index, the tiling result will be i0, j0, i1, j1, k0, i2, j2, k1, i3, j3
+    std::string tile_struct;
+    // The storage type of the read cache
+    std::string read_cache_memory_type;
+    // The tiled levels at which read cache blocks are inserted
+    std::vector<int> read_cache_levels;
+    // The storage type of the write cache
+    std::string write_cache_memory_type;
+    // The tiled levels at which write cache blocks are inserted
+    std::vector<int> write_cache_levels;
+  };
+
+  static const std::unordered_map<common::Target::Arch, Config> kConfigs;
+
+  MultiLevelTiling(const common::Target& target, const Config& config);
+  ~MultiLevelTiling() = default;
+
+  // Initialize the AutoGenRule; it must be called before further actions.
+  // Returns the RuleApplyType indicating whether and how the rule can be
+  // applied on the given schedule.
+  RuleApplyType Init(ir::IRSchedule* init_schedule) override;
+
+  // Applies the rule on the ir::ModuleExpr for a schedule block specified by
+  // an index between 0 (inclusive) and NumberApplicable() (exclusive)
+  void Apply(int index) override;
+
+  // Returns the name of the rule, used for debug.
+  std::string GetRuleName() const override;
+
+  // Returns true if sche_block_realize can be handled by MultiLevelTiling
+  bool MeetCondition(const ir::ScheduleBlockRealize& sche_block_realize) const;
+
+  RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override;
+
+  std::vector<SearchState> ApplyOnBlock(SearchState state, const std::string& block_name) override;
+
+  // Sample a pair of integers (a, b) such that a * b == extent
+  template <typename T>
+  std::vector<T> SampleSplitTwo(T extent) const {
+    std::vector<std::vector<T>> candidates;
+    for (T div = 1; div <= sqrt(extent); ++div) {
+      if (extent % div == 0) {
+        candidates.push_back({T(div), extent / div});
+      }
+    }
+    if (candidates.size() == 0) {
+      return {1, T(extent)};
+    }
+    int index = rand() % candidates.size();
+    std::vector<T> pick = candidates[index];
+    if (rand() % 2 != 0) {
+      T tmp = pick[0];
+      pick[0] = pick[1];
+      pick[1] = tmp;
+    }
+    return pick;
+  }
+
+  // Sample num_split integers whose product equals extent
+  template <typename T>
+  std::vector<T> SampleTileSplit(T extent, int num_split) const {
+    CHECK_GT(num_split, 0) << "num_split in SampleTileSplit must be greater than 0";
+    if (num_split == 1) {
+      return {extent};
+    }
+    std::vector<T> two_split = SampleSplitTwo<T>(extent);
+    if (num_split == 2) {
+      return two_split;
+    }
+    int half = num_split >> 1;
+    std::vector<T> result = SampleTileSplit<T>(two_split[0], half);
+    std::vector<T> remind = SampleTileSplit<T>(two_split[1], num_split - half);
+    result.insert(result.end(), remind.begin(), remind.end());
+    return result;
+  }
+
+ private:
+  void ApplyTiling(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
+  void ApplyCacheRead(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
+  void ApplyCacheWrite(ir::IRSchedule* ir_schedule, ir::Expr& block_expr);
+
+ private:
+  std::vector<ir::Expr> all_block_realizes_;
+  std::vector<int> applicable_indices_;
+
+  Config config_;
+  std::vector<int> s_indices_;
+  std::vector<int> r_indices_;
+  std::vector<std::vector<ir::Expr>> tile_loops_;
+
+  // A factor to limit the split factors within the max thread number per block
+  int max_factor_ = 1024;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
new file mode 100644
index 0000000000000..91ddf361da4d3
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
@@ -0,0 +1,548 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
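+
+// An illustrative trace of the sampling invariant exercised below (the factor
+// values are hypothetical): SampleTileSplit(32, 3) first calls SampleSplitTwo
+// to split 32 into two factors, say {2, 16}, then recursively splits the
+// parts, e.g. yielding {2, 2, 8}; the product of the result always equals
+// the input extent.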
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h" + +#include +#include + +#include +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" +#include "cinn/cinn.h" +#include "cinn/frontend/syntax.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/poly/stage.h" +#include "cinn/utils/string.h" +#include "tests/program_builder.h" + +namespace cinn { +namespace auto_schedule { + +TEST(MultiLevelTile, SampleSplitTwo) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); + + for (int i = 0; i < 100; ++i) { + size_t number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] + std::vector split = multi_level_tiling.SampleSplitTwo(number_to_split); + EXPECT_EQ(split.size(), 2UL); + EXPECT_EQ(split[0] * split[1], number_to_split); + } +} + +TEST(MultiLevelTile, SampleTileSplit) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); + + for (int i = 0; i < 100; ++i) { + int number_to_split = rand() % 65535 + 2; // random number in [2, 2^16] + int split_size = rand() % 5 + 1; // random in [1, 5] + std::vector split = multi_level_tiling.SampleTileSplit(number_to_split, split_size); + EXPECT_EQ(split.size(), static_cast(split_size)); + int product = 1; + for (int num : split) { + product *= num; + } + EXPECT_EQ(product, number_to_split); + } +} + +TEST(MultiLevelTile, SimpleLoops) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + Expr M(32); + Expr N(128); + + Placeholder A("A", {M}); + Placeholder B("B", {N}); + + ir::Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); + + poly::StageMap stages = CreateStages({C}); + std::vector funcs = + lang::LowerVec("TestMultiLevelTile_SimpleLoops", stages, {C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr before MultiLevelTiling: "; + VLOG(6) << ast_expr; + + MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); + ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); + SearchState state(ir_schedule, 0, {}); + EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); + multi_level_tiling.ApplyRandomly(); + + // ApplyOnBlock + EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); + + auto test_func = [](ir::IRSchedule* ir_sch) { + std::vector exprs = ir_sch->GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + std::stringstream ss; + ss << exprs[0]; + std::string expr_str = ss.str(); + VLOG(6) << expr_str; + }; + + test_func(&ir_schedule); + 
test_func(&new_states[0]->ir_schedule); +} + +// TODO: fix in future +/* +TEST(MulitLevelTile, MatrixMultiply) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + Expr M(32); + Expr N(32); + Expr K(32); + + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + Var k(K.as_int32(), "reduce_axis_k"); + ir::Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + poly::StageMap stages = CreateStages({C}); + std::vector funcs = + lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr before MultiLevelTiling: "; + VLOG(6) << ast_expr; + + MultiLevelTiling multi_level_tiling(target, MultiLevelTiling::kConfigs.at(target.arch)); + ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); + SearchState state(ir_schedule, 0, {}); + EXPECT_EQ(multi_level_tiling.Init(&ir_schedule), RuleApplyType::kApplyAndPruneOtherRules); + EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1); + multi_level_tiling.ApplyRandomly(); + + // ApplyOnBlock + EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"), RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = multi_level_tiling.ApplyOnBlock(state, "C"); + + auto test_func = [](ir::IRSchedule* ir_sch) { + std::vector exprs = ir_sch->GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + std::stringstream ss; + ss << exprs[0]; + std::string expr_str = ss.str(); + VLOG(6) << expr_str; + }; + + test_func(&ir_schedule); + test_func(&new_states[0]->ir_schedule); +} +*/ +class TestMultiLevelTiling : public TestAutoGenRuleBase { + public: + int fixed_rand_seed = 1; + std::vector default_input_names; + std::vector default_output_names; +}; + +TEST_F(TestMultiLevelTiling, Matmul) { + default_input_names = {"X", "Y"}; + default_output_names = {"temp_matmul_out"}; + std::vector X_shape = {32, 32}; + std::vector Y_shape = {32, 32}; + std::vector out_shape = {32, 32}; + + Initialize(common::DefaultNVGPUTarget()); + frontend::Program matmul_op = tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}}); + ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed); + SearchState state(ir_schedule); + VLOG(6) << "Original state:\n" << state->DebugString(); + + // Apply MultiLevelTiling + MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); + EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), + RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); + VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); + std::string ir = GetIR(new_states[0]->ir_schedule); + std::string expected_ir = R"ROC(Expr 0 { +{ + ScheduleBlock(root) + { + { + thread_bind[blockIdx.x] for (i_j_fused, 0, 4) + { + thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 1) + { + serial for (i_1, 0, 1) + { + serial for (j_1, 0, 1) + { + serial for (i_2, 0, 1) + { + serial for (j_2, 0, 1) + { + serial for (i_3, 0, 8) + { + serial for (j_3, 0, 32) + { + ScheduleBlock(temp_matmul_out__reduce_init) + { + i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))) + { + temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * 
j_1) + ((32 * j_2) + j_3))] = 0.00000000f + } + } + } + } + } + } + { + serial for (reduce_k_0, 0, 4) + { + serial for (ax0_0_ax1_0_fused, 0, 256) + { + ScheduleBlock(Y_reshape_shared_temp_buffer) + { + v0, v1 = axis.bind(((ax0_0_ax1_0_fused / 32) + (8 * reduce_k_0)), ((ax0_0_ax1_0_fused % 32) + (32 * j_1))) + attrs(compute_at_extra_var:ax0_0,ax1_0, cooperative_process:0) + { + Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1] + } + } + } + serial for (ax0_ax1_fused, 0, 64) + { + ScheduleBlock(X_reshape_shared_temp_buffer) + { + v0, v1 = axis.bind(((ax0_ax1_fused / 8) + ((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused)))), ((ax0_ax1_fused % 8) + (8 * reduce_k_0))) + attrs(compute_at_extra_var:ax0,ax1, cooperative_process:0) + { + X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1] + } + } + } + serial for (reduce_k_1, 0, 1) + { + serial for (i_2, 0, 1) + { + serial for (j_2, 0, 1) + { + serial for (reduce_k_2, 0, 8) + { + serial for (i_3, 0, 8) + { + serial for (j_3, 0, 32) + { + ScheduleBlock(temp_matmul_out_local_temp_buffer) + { + i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))) + read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)]) + write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)]) + { + temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))])) + } + } + } + } + } + } + } + } + } + serial for (ax0_1, 0, 8) + { + serial for (ax1_1, 0, 32) + { + ScheduleBlock(temp_matmul_out) + { + v0, v1 = axis.bind((((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused))) + ax0_1), ((32 * j_1) + ax1_1)) + attrs(reverse_compute_at_extra_var:ax0_1,ax1_1) + { + temp_matmul_out[v0, v1] = temp_matmul_out_local_temp_buffer[v0, v1] + } + } + } + } + } + } + } + } + } + } + } +} +} // end Expr 0 +)ROC"; + ASSERT_EQ(ir, expected_ir); + + // build ir::Module and debug source code + auto ir_module = BuildIRModule(new_states[0]->ir_schedule); + auto source_code = GenSourceCode(ir_module); + VLOG(6) << "scheduled source code:\n" << source_code; + + // execute and check precision + CheckResult( + GenExecutableKernel(ir_module), + GenExecutableKernel(BuildIRModule(MakeIRSchedule(matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))), + default_input_names, + default_output_names, + {X_shape, Y_shape}, + {out_shape}, + target_); +} + +TEST_F(TestMultiLevelTiling, ReduceSum) { + default_input_names = {"X"}; + default_output_names = {"var_0_tmp"}; + std::vector X_shape = {1, 16, 32}; + std::vector out_shape = {1, 16, 1}; + std::vector reduce_dim = {2}; + + Initialize(common::DefaultNVGPUTarget()); + frontend::Program reduce_sum_op = + tests::OpBuilder("reduce_sum").Build({{"X", X_shape}}, {{"dim", reduce_dim}, {"keep_dim", false}}); + ir::IRSchedule ir_schedule = 
MakeIRSchedule(reduce_sum_op); + SearchState state(ir_schedule); + VLOG(6) << "Original state:\n" << state->DebugString(); + + // Apply MultiLevelTiling + MultiLevelTiling multi_level_tiling(target_, MultiLevelTiling::kConfigs.at(target_.arch)); + // EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), RuleApplyType::kCannotApply); +} + +TEST_F(TestMultiLevelTiling, Pool2d) { + default_input_names = {"input"}; + default_output_names = {"var_0"}; + std::vector input_shape{2, 8, 16, 16}; + std::vector output_shape{2, 8, 8, 8}; + std::string pooling_type = "max"; + std::vector ksize{3, 3}; + std::vector strides{2, 2}; + std::vector paddings{1, 1, 1, 1}; + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + std::string data_format = "NCHW"; + bool adaptive = false; + std::string padding_algorithm = "EXPLICIT"; + frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}}, + {{"pool_type", pooling_type}, + {"kernel_size", ksize}, + {"stride_size", strides}, + {"padding_size", paddings}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}, + {"global_pooling", global_pooling}, + {"data_format", data_format}, + {"adaptive", adaptive}, + {"padding_algorithm", padding_algorithm}}); + + Initialize(common::DefaultNVGPUTarget()); + ir::IRSchedule ir_schedule = MakeIRSchedule(pool2d_program, fixed_rand_seed); + SearchState state(ir_schedule); + VLOG(6) << "Original state:\n" << state->DebugString(); + + // Apply MultiLevelTiling + MultiLevelTiling::Config mlt_config = { + /*bind_axis*/ std::vector{"blockIdx.x", "threadIdx.x"}, + /*tile_struct*/ std::string("SSRS"), + /*read_cache_memory_type*/ std::string("shared"), + /*read_cache_levels*/ std::vector{3}, + /*write_cache_memory_type*/ std::string("local"), + /*write_cache_levels*/ std::vector{2}, + }; + MultiLevelTiling multi_level_tiling(target_, mlt_config); + EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, default_output_names[0]), + RuleApplyType::kApplyAndPruneOtherRules); + auto new_states = multi_level_tiling.ApplyOnBlock(state, default_output_names[0]); + VLOG(6) << "After MultiLevelTiling, state:\n" << new_states[0]->DebugString(); + + std::string ir = GetIR(new_states[0]->ir_schedule); + std::string expected_ir = R"ROC(Expr 0 { +{ + ScheduleBlock(root) + { + serial for (i, 0, 2) + { + serial for (j, 0, 8) + { + serial for (k, 0, 18) + { + serial for (a, 0, 18) + { + ScheduleBlock(pad_temp_0) + { + i0, i1, i2, i3 = axis.bind(i, j, k, a) + pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) + } + } + } + } + } + } +} +} // end Expr 0 +Expr 1 { +{ + ScheduleBlock(root_0) + { + { + thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16) + { + thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4) + { + serial for (i_1, 0, 1) + { + serial for (j_1, 0, 4) + { + serial for (k_1, 0, 1) + { + serial for (a_1, 0, 4) + { + ScheduleBlock(var_0__reduce_init) + { + i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) + { + var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = 
-3.40282347e+38f + } + } + } + } + } + } + { + serial for (kernel_idx, 0, 3) + { + serial for (kernel_idx_0, 0, 3) + { + serial for (ax0_ax1_ax2_ax3_fused, 0, 28) + { + ScheduleBlock(pad_temp_0_shared_temp_buffer) + { + v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) + attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0) + { + pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] + } + } + } + serial for (i_1, 0, 1) + { + serial for (j_1, 0, 4) + { + serial for (k_1, 0, 1) + { + serial for (a_1, 0, 4) + { + ScheduleBlock(var_0_local_temp_buffer) + { + i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) + read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) + write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) + { + var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) + } + } + } + } + } + } + } + } + serial for (ax0_0, 0, 1) + { + serial for (ax1_0, 0, 4) + { + serial for (ax2_0, 0, 1) + { + serial for (ax3_0, 0, 4) + { + ScheduleBlock(var_0) + { + v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) + attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) + { + var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] + } + } + } + } + } + } + } + } + } + } + } +} +} // end Expr 1 +)ROC"; + ASSERT_EQ(ir, expected_ir); + + // build ir::Module and debug source code + auto ir_module = BuildIRModule(new_states[0]->ir_schedule); + auto source_code = GenSourceCode(ir_module); + VLOG(6) << "scheduled source code:\n" << source_code; + + // execute and check precision + CheckResult(GenExecutableKernel(ir_module), + GenExecutableKernel( + BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), + default_input_names, + default_output_names, + {input_shape}, + {output_shape}, + target_); +} + +} // namespace 
auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc new file mode 100644 index 0000000000000..795a1bdc488fb --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h" + +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace auto_schedule { + +SkipRule::SkipRule(const common::Target& target) : AutoGenRule(target) {} + +RuleApplyType SkipRule::Init(ir::IRSchedule* ir_schedule) { + ir_schedule_ = ir_schedule; + num_applicable_ = 1; + return RuleApplyType::kApply; +} + +std::string SkipRule::GetRuleName() const { return "SkipRule"; } + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h new file mode 100644 index 0000000000000..0b7f26f2fdd8b --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
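+
+// SkipRule is effectively a no-op rule: Init always reports one applicable
+// schedule, Apply changes nothing, and ApplyOnBlock returns the input state
+// unchanged, so the search may explicitly choose to leave a block as-is.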
+ +#pragma once + +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +class SkipRule : public AutoGenRule { + public: + SkipRule(const common::Target& target); + ~SkipRule() = default; + + RuleApplyType Init(ir::IRSchedule* init_schedule) override; + + void Apply(int index) override {} + + std::string GetRuleName() const override; + + RuleApplyType AnalyseApplyType(SearchState state, const std::string& block_name) const override { + return RuleApplyType::kApply; + } + + std::vector ApplyOnBlock(SearchState state, const std::string& block_name) override { return {state}; } +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc new file mode 100644 index 0000000000000..9031605a7508c --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h" + +#include +#include + +#include +#include +#include + +#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h" +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/tensor.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace auto_schedule { + +TEST(SkipRule, Basic) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + Expr M(32); + Expr N(128); + + Placeholder A("A", {M}); + Placeholder B("B", {N}); + + ir::Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); + + poly::StageMap stages = CreateStages({C}); + std::vector funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr before SkipRule: "; + VLOG(6) << ast_expr; + + SkipRule skip_rule(target); + ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); + SearchState state(ir_schedule, 0, {}); + + EXPECT_EQ(skip_rule.Init(&ir_schedule), RuleApplyType::kApply); + EXPECT_EQ(skip_rule.NumberApplicable(), 1); + skip_rule.ApplyRandomly(); + + // ApplyOnBlock + EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); + std::vector states = skip_rule.ApplyOnBlock(state, "C"); + + auto test_func = [&ast_expr](ir::IRSchedule* ir_sch) { + std::vector exprs = ir_sch->GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + EXPECT_EQ(ast_expr, exprs[0]); + }; + + test_func(&ir_schedule); + test_func(&states[0]->ir_schedule); +} + +TEST(SkipRule, 
ApplyOnSpecificBlock) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + Expr M(32); + Expr N(128); + + Placeholder A("A", {M}); + Placeholder B("B", {N}); + + ir::Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return A(i) + B(j); }, "C"); + + poly::StageMap stages = CreateStages({C}); + std::vector funcs = lang::LowerVec("TestSkipRule_Basic", stages, {C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Expr before SkipRule: "; + VLOG(6) << ast_expr; + + SkipRule skip_rule(target); + ir::IRSchedule ir_schedule(ir::ModuleExpr({ast_expr})); + SearchState state(ir_schedule, 0, {}); + + EXPECT_EQ(skip_rule.AnalyseApplyType(state, "C"), RuleApplyType::kApply); + std::vector states = skip_rule.ApplyOnBlock(state, "C"); + + std::vector exprs = states[0]->ir_schedule.GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + EXPECT_EQ(ast_expr, exprs[0]); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc new file mode 100644 index 0000000000000..9ad001a23bdcc --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
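+
+// The helpers in this file drive a full mini-pipeline for AutoGenRule tests:
+// lower a frontend::Program to an ir::IRSchedule, rebuild an ir::Module from
+// the (possibly rule-modified) schedule, generate and JIT-compile source
+// code, and finally compare the kernel's outputs against the manually
+// scheduled version via CheckResult.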
+ +#include "cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h" + +#include +#include +#include +#include + +#include "cinn/auto_schedule/analysis/analyze_ir.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/cinn.h" +#include "cinn/frontend/optimize.h" +#include "cinn/hlir/framework/instruction.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/framework/tensor.h" +#include "cinn/optim/transform_gpu_forloop.h" +#ifdef CINN_WITH_CUDA +#include +#endif + +namespace cinn { +namespace auto_schedule { + +using ::cinn::hlir::framework::Instruction; +using ::cinn::hlir::framework::Scope; +using ::cinn::hlir::framework::Shape; +using ::cinn::hlir::framework::Tensor; + +void TestAutoGenRuleBase::Initialize(const common::Target& target) { + target_ = target; + backend_compier_ = backends::Compiler::Create(target); +} + +ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test_program, + utils::LinearRandomEngine::StateType rand_seed, + bool apply_manual_schedule) { + Context::Global().ResetNameId(); + + auto graph = std::make_shared(test_program, target_); + hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); + LOG_IF(WARNING, graph->fusion_groups.size() > 1) << "Test Graph has more than 1 group"; + auto& dtype_dict = graph->GetMutableAttrs>("inferdtype"); + auto& shape_dict = graph->GetMutableAttrs>("infershape"); + hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); + + if (apply_manual_schedule) { + lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); + } else { + lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); + } + CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; + + std::vector bodys; + for (auto&& func : lowered_funcs_) { + bodys.emplace_back(func->body); + } + return ir::IRSchedule(ir::ModuleExpr({std::move(bodys)}), rand_seed); +} + +std::string TestAutoGenRuleBase::GetIR(const ir::IRSchedule& schedule) { + const auto& exprs = schedule.GetModule().GetExprs(); + std::stringstream module_stream; + for (auto i = 0; i < exprs.size(); ++i) { + module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr " << i << "\n"; + } + return module_stream.str(); +} + +ir::Module TestAutoGenRuleBase::BuildIRModule(const ir::IRSchedule& schedule) { + auto&& updated_bodys = schedule.GetModule().GetExprs(); + CHECK_EQ(lowered_funcs_.size(), updated_bodys.size()) << "associated exprs size not equal"; + + ir::Module::Builder builder("test_bulder", this->target_); + for (int i = 0; i < lowered_funcs_.size(); ++i) { + ir::Expr func_body = updated_bodys.at(i); + const ir::LoweredFunc& ori_func = lowered_funcs_.at(i); + auto&& new_func = UpdateFuncWithNewBody(target_, ori_func, func_body); + builder.AddFunction(new_func); + } + + return builder.Build(); +} + +std::string TestAutoGenRuleBase::GenSourceCode(const ir::Module& ir_module) { + std::unique_ptr codegen; +#ifdef CINN_WITH_CUDA + if (target_ == common::DefaultNVGPUTarget()) { + codegen = std::make_unique(this->target_); + } else { + codegen = std::make_unique(this->target_, CodeGenCX86::Feature::AVX512); + } +#else + codegen = std::make_unique(this->target_, CodeGenCX86::Feature::AVX512); +#endif + codegen->SetInlineBuiltinCodes(false); + return codegen->Compile(ir_module, CodeGenC::OutputKind::CImpl); +} + +raw_func_type TestAutoGenRuleBase::GenExecutableKernel(const ir::Module& ir_module) { + auto&& func_name = 
lowered_funcs_.front()->name; + // Compile to machine code + backend_compier_->Build(ir_module); + auto test_func_ptr = reinterpret_cast(backend_compier_->Lookup(func_name)); + return test_func_ptr; +} + +void MemoryCopy(const float* src, float* dst, int numel, std::string type) { +#ifdef CINN_WITH_CUDA + if (type == "DeviceToHost") { + cudaMemcpy(dst, src, numel * sizeof(float), cudaMemcpyDeviceToHost); + return; + } else if (type == "HostToDevice") { + cudaMemcpy(dst, src, numel * sizeof(float), cudaMemcpyHostToDevice); + return; + } +#endif + if (type == "HostToHost") { + for (size_t i = 0; i < numel; ++i) { + dst[i] = src[i]; + } + } else { + LOG(FATAL) << "Unknown memory copy type"; + } +} + +void AddDataToScope( + Scope* scope, const common::Target& target, float* data_ptr, std::string name, const std::vector& shape) { + auto* var = scope->Var(name); + auto& tensor = absl::get(*var); + CHECK(shape.size()) << "The size of shape can not be 0."; + Shape cinn_shape(shape); + tensor->Resize(cinn_shape); + auto* tgt_data_ptr = tensor->mutable_data(target); + std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; + MemoryCopy(data_ptr, tgt_data_ptr, cinn_shape.numel(), mem_cpy_type); +} + +void CheckResult(raw_func_type test_func, + raw_func_type expected_func, + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& input_shapes, + const std::vector>& output_shapes, + const common::Target& target) { + CHECK(input_names.size()) << "The number of inputs must be greater than 0."; + CHECK(output_names.size()) << "The number of outputs must be greater than 0."; + CHECK_EQ(input_names.size(), input_shapes.size()) << "The quantity of input_names and input_shapes must be equal."; + CHECK_EQ(output_names.size(), output_shapes.size()) + << "The quantity of output_names and output_shapes must be equal."; + + // Initialize data + std::vector input_data_ptrs(input_names.size()); + for (int i = 0; i < input_shapes.size(); ++i) { + int input_data_numel = + std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1, [](int a, int b) { return a * b; }); + input_data_ptrs[i] = reinterpret_cast(malloc(input_data_numel * sizeof(float))); + for (int j = 0; j < input_data_numel; ++j) { + input_data_ptrs[i][j] = (rand() * 1.f) / RAND_MAX; + } + } + std::vector test_output_data_ptrs(output_names.size()); + std::vector expected_output_data_ptrs(output_names.size()); + std::vector output_data_numels(output_shapes.size()); + for (int i = 0; i < output_shapes.size(); ++i) { + output_data_numels[i] = + std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1, [](int a, int b) { return a * b; }); + test_output_data_ptrs[i] = reinterpret_cast(malloc(output_data_numels[i] * sizeof(float))); + memset(test_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); + expected_output_data_ptrs[i] = reinterpret_cast(malloc(output_data_numels[i] * sizeof(float))); + memset(expected_output_data_ptrs[i], 0, output_data_numels[i] * sizeof(float)); + } + + auto launch_kernel_fn = [&](raw_func_type& raw_func, std::vector& output_data_ptrs) { + // Initialize scope + Scope scope; + // Initialize input data in scope. + for (int i = 0; i < input_names.size(); ++i) { + AddDataToScope(&scope, target, input_data_ptrs[i], input_names[i], input_shapes[i]); + } + // Initialize output data in scope. 
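+    // (they are pre-registered so the kernel has output buffers to write
+    // into; results are copied back out of the scope after instr.Run())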
+ for (int i = 0; i < output_names.size(); ++i) { + AddDataToScope(&scope, target, output_data_ptrs[i], output_names[i], output_shapes[i]); + } + + // Create Instruction and run + Instruction instr(target, &scope, input_names, output_names); + CHECK(raw_func) << "The raw_func can not be nullptr."; + instr.SetLoweredFunc(reinterpret_cast(raw_func)); + // should call Finalize explicitly before Run + instr.Finalize(); + instr.Run(); + + // data + for (int i = 0; i < output_names.size(); ++i) { + const float* result_ptr = scope.GetTensor(output_names[i])->data(); + std::string mem_cpy_type = target == common::DefaultNVGPUTarget() ? "DeviceToHost" : "HostToHost"; + MemoryCopy(result_ptr, output_data_ptrs[i], output_data_numels[i], mem_cpy_type); + } + }; + + // launch and execute test and expected kernel separately + launch_kernel_fn(test_func, test_output_data_ptrs); + launch_kernel_fn(expected_func, expected_output_data_ptrs); + + // Check result + for (int i = 0; i < output_shapes.size(); ++i) { + for (int j = 0; j < output_data_numels[i]; ++j) { + ASSERT_NEAR(test_output_data_ptrs[i][j], expected_output_data_ptrs[i][j], 1e-4); + } + } + + // Free memory + for (auto ptr : input_data_ptrs) { + free(ptr); + } + for (auto ptr : test_output_data_ptrs) { + free(ptr); + } + for (auto ptr : expected_output_data_ptrs) { + free(ptr); + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h new file mode 100644 index 0000000000000..d8f8feb46babb --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h @@ -0,0 +1,92 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "cinn/backends/compiler.h" +#include "cinn/common/target.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/scope.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/utils/random_engine.h" + +namespace cinn { +namespace auto_schedule { + +/* @brief: Function pointer of executable code compiled by CINN. + * @params-1: Pointers to all arguments, including input and output. + * @params-2: The number of Arguments. 
+ * @return: void + */ +using raw_func_type = void (*)(void**, int32_t); + +// A base utility class for testing AutoGenRule +class TestAutoGenRuleBase : public ::testing::Test { + public: + void SetUp() override { + srand(0); + Context::Global().ResetNameId(); + } + // Initialize context for specified target + void Initialize(const common::Target& target); + + // construct an ir::IRSchedule by lowering the specified for following AutoGenRule test + ir::IRSchedule MakeIRSchedule(const frontend::Program& test_program, + utils::LinearRandomEngine::StateType rand_seed = -1, + bool apply_manual_schedule = false); + + // Get the IR of bodies in IRSchedule + std::string GetIR(const ir::IRSchedule& schedule); + + // build ir::Module from the original lowered funcs with their bodies updated by the schedule + ir::Module BuildIRModule(const ir::IRSchedule& schedule); + + // generate source code with the built ir module + std::string GenSourceCode(const ir::Module& ir_module); + + // generate executable kernel function with the built ir module + raw_func_type GenExecutableKernel(const ir::Module& ir_module); + + protected: + common::Target target_; + std::vector lowered_funcs_; + std::unique_ptr backend_compier_; +}; + +/* @brief: Interface for checking function correctness. + * @params-1: Function pointer of the function to be tested. + * @params-2: Expected function pointer for comparison. + * @params-3: Names of input data. + * @params-4: Names of output data. + * @params-5: Shapes of the input data, each input corresponds to a std::vector. + * @params-6: Shapes of the output data, each output corresponds to a std::vector. + * @params-7: The Target expressing computing platform and architecture of the function to be tested. + * @return: void + */ +void CheckResult(raw_func_type test_func, + raw_func_type expected_func, + const std::vector& input_names, + const std::vector& output_names, + const std::vector>& input_shapes, + const std::vector>& output_shapes, + const common::Target& target); + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler.cc b/paddle/cinn/auto_schedule/search_space/block_sampler.cc new file mode 100644 index 0000000000000..66cfb8d7bfba1 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/block_sampler.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/search_space/block_sampler.h" + +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace auto_schedule { + +std::unique_ptr BlockSampler::Make(const std::vector& all_blocks, + bool default_remove_policy, + const std::string& strategy, + utils::LinearRandomEngine::StateType rand_seed, + const std::vector& weights) { + CHECK_GT(all_blocks.size(), 0) << "Empty block list"; + if (strategy == "traversal") { + VLOG(6) << "Init TraversalBlockSampler with block num = " << all_blocks.size(); + return std::make_unique(all_blocks, default_remove_policy); + } else if (strategy == "probabilistic") { + VLOG(6) << "Init ProbabilisticBlockSampler with block num = " << all_blocks.size(); + return std::make_unique(all_blocks, default_remove_policy, rand_seed, weights); + } + + LOG(FATAL) << "Unimplemented strategy:" << strategy; + return nullptr; +} + +BlockSampler::BlockSampler(const std::vector& all_blocks, bool default_remove_policy) { + default_remove_policy_ = default_remove_policy; + std::transform(all_blocks.begin(), all_blocks.end(), std::back_inserter(all_blocks_), [](const ir::Expr& block_expr) { + const ir::ScheduleBlockRealize* block_realize = block_expr.As(); + const ir::ScheduleBlock* block = block_realize->schedule_block.As(); + return block->name; + }); +} + +std::string TraversalBlockSampler::NextBlock(bool remove) { + if (cur_idx_ < all_blocks_.size()) { + VLOG(6) << "[TraversalBlockSampler] next block: " << all_blocks_.at(cur_idx_); + std::string block_name = all_blocks_.at(cur_idx_); + if (remove) { + ++cur_idx_; + } + return block_name; + } + + VLOG(6) << "[TraversalBlockSampler] next block: empty"; + return ""; +} + +ProbabilisticBlockSampler::ProbabilisticBlockSampler(const std::vector& all_blocks, + bool default_remove_policy, + utils::LinearRandomEngine::StateType rand_seed, + const std::vector& weights) + : BlockSampler(all_blocks, default_remove_policy), weights_(weights), rand_seed_(rand_seed) { + if (weights.empty()) { + weights_.resize(all_blocks.size(), 1); + } else { + CHECK_EQ(all_blocks.size(), weights_.size()); + } + remains_ = all_blocks.size(); +} + +std::string ProbabilisticBlockSampler::NextBlock(bool remove) { + if (remains_ == 0) { + return ""; + } + int block_idx = utils::SampleDiscreteFromDistribution(weights_, &rand_seed_); + if (remove) { + weights_[block_idx] = 0; + --remains_; + } + VLOG(6) << "[ProbabilisticBlockSampler] next block: " << all_blocks_.at(block_idx); + return all_blocks_.at(block_idx); +} + +} // namespace auto_schedule +} // namespace cinn \ No newline at end of file diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler.h b/paddle/cinn/auto_schedule/search_space/block_sampler.h new file mode 100644 index 0000000000000..7135afffb0280 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/block_sampler.h @@ -0,0 +1,115 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
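+
+// A minimal usage sketch (the `all_blocks` vector is assumed to come from
+// IRSchedule::GetAllBlocks()):
+//
+//   auto sampler = BlockSampler::Make(all_blocks, /*default_remove_policy=*/true,
+//                                     "traversal");
+//   for (std::string name = sampler->NextBlock(); !name.empty();
+//        name = sampler->NextBlock()) {
+//     // apply rules to the ScheduleBlock called `name`
+//   }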
+ +#pragma once + +#include +#include +#include + +#include "cinn/ir/ir_base.h" +#include "cinn/utils/random_engine.h" + +namespace cinn { +namespace auto_schedule { + +class SearchState; + +// Select the next block to be operated for SearchState during the search process +class BlockSampler { + public: + /** + * @brief Create a BlockSampler with the specific strategy name and necessary construct parameters. + * @param all_blocks All possible blocks to be sampled. + * @param default_remove_policy The default option to determine whether to delete the next block after selecting it. + * @param strategy The block sampling strategy. + * Currently, the available strategies are "traversal" and "probabilistic", + * where "traversal" means to select blocks one by one until all blocks are traversed, + * and "probabilistic" means randomly picking blocks according to the given distribution. + * @param weights Used for the probabilistic policy, giving each candidate a weight. + */ + static std::unique_ptr Make(const std::vector& all_blocks, + bool default_remove_policy = true, + const std::string& strategy = "traversal", + utils::LinearRandomEngine::StateType rand_seed = 0, + const std::vector& weights = {}); + + // Return the name of sample strategy + virtual const char* Name() const = 0; + + // Reset associated states to sample at the beginning + virtual void Reset() = 0; + + // Select a block with default remove policy. + std::string NextBlock() { return NextBlock(default_remove_policy_); } + + protected: + // A BlockSampler object should be created with the static function Make() + BlockSampler(const std::vector& all_blocks, bool default_remove_policy); + + // Select a block to apply rule + // The param remove is used to determine whether to delete the next block after selecting it, + // If remove == true, it will not be sampled in the future. + virtual std::string NextBlock(bool remove) = 0; + + // The names of all blocks + // Because the Block Expr will be changed in the search process, the name is saved for indexing + std::vector all_blocks_; + + // The default policy to determine whether to delete the next block after selecting it. + bool default_remove_policy_; +}; + +// Sample blocks with traversal strategy, +// witch means to select blocks one by one until all blocks are traversed. +class TraversalBlockSampler : public BlockSampler { + public: + TraversalBlockSampler(const std::vector& all_blocks, bool default_remove_policy) + : BlockSampler(all_blocks, default_remove_policy), cur_idx_(0) {} + + const char* Name() const override { return "traversal"; } + + void Reset() override { cur_idx_ = 0; } + + private: + std::string NextBlock(bool remove) override; + + private: + int cur_idx_; +}; + +// Sample blocks with probabilistic strategy, +// witch means randomly picking blocks according to the given distribution. 
+class ProbabilisticBlockSampler : public BlockSampler { + public: + ProbabilisticBlockSampler(const std::vector& all_blocks, + bool default_remove_policy, + utils::LinearRandomEngine::StateType rand_seed = 0, + const std::vector& weights = {}); + + const char* Name() const override { return "probabilistic"; } + + void Reset() override {} + + private: + std::string NextBlock(bool remove) override; + + private: + std::vector weights_; + utils::LinearRandomEngine::StateType rand_seed_; + int remains_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc b/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc new file mode 100644 index 0000000000000..ef07d964dd153 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_space/block_sampler_test.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/search_space/block_sampler.h" + +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace auto_schedule { + +std::vector CreateTestBlocks() { + std::vector blocks; + for (int i = 0; i < 3; ++i) { + ir::Expr block = ir::ScheduleBlock::Make({}, {}, {}, "block_" + std::to_string(i), ir::Expr()); + blocks.push_back(ir::ScheduleBlockRealize::Make({}, block)); + } + return blocks; +} + +TEST(BlockSampler, Make) { + std::vector mock_blocks = CreateTestBlocks(); + auto traversal_block_sampler = BlockSampler::Make(mock_blocks, true, "traversal"); + ASSERT_STREQ(traversal_block_sampler->Name(), "traversal"); + auto probabilistic_block_sampler = BlockSampler::Make(mock_blocks, true, "probabilistic"); + ASSERT_STREQ(probabilistic_block_sampler->Name(), "probabilistic"); +} + +TEST(TraversalBlockSampler, NextBlock) { + std::vector blocks = CreateTestBlocks(); + auto traversal_block_sampler = BlockSampler::Make(blocks, true, "traversal"); + ASSERT_EQ("block_0", traversal_block_sampler->NextBlock()); + ASSERT_EQ("block_1", traversal_block_sampler->NextBlock()); + ASSERT_EQ("block_2", traversal_block_sampler->NextBlock()); + ASSERT_EQ("", traversal_block_sampler->NextBlock()); + traversal_block_sampler->Reset(); + ASSERT_EQ("block_0", traversal_block_sampler->NextBlock()); + + traversal_block_sampler = BlockSampler::Make(blocks, false, "traversal"); + ASSERT_EQ("block_0", traversal_block_sampler->NextBlock()); + ASSERT_EQ("block_0", traversal_block_sampler->NextBlock()); +} + +TEST(ProbabilisticBlockSampler, NextBlock) { + std::vector blocks = CreateTestBlocks(); + auto probabilistic_block_sampler = BlockSampler::Make(blocks, false, "probabilistic", 0, {4, 2, 1}); + std::string block_name; + for (int i = 0; i < 20; ++i) { + block_name = probabilistic_block_sampler->NextBlock(); + VLOG(6) << "next block name: " << block_name; + } + + probabilistic_block_sampler = BlockSampler::Make(blocks, true, "probabilistic", 0, {4, 2, 1}); + probabilistic_block_sampler->NextBlock(); + probabilistic_block_sampler->NextBlock(); + 
+  probabilistic_block_sampler->NextBlock();
+  ASSERT_EQ("", probabilistic_block_sampler->NextBlock());
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler.cc b/paddle/cinn/auto_schedule/search_space/rule_sampler.cc
new file mode 100644
index 0000000000000..3951af427081f
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/rule_sampler.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/rule_sampler.h"
+
+#include <memory>
+#include <string>
+
+namespace cinn {
+namespace auto_schedule {
+
+std::unique_ptr<RuleSampler> RuleSampler::Make(const std::vector<AutoGenRule*>& potential_rules,
+                                               bool default_remove_policy,
+                                               const std::string& strategy,
+                                               utils::LinearRandomEngine::StateType rand_seed,
+                                               const std::vector<int>& weights) {
+  CHECK_GT(potential_rules.size(), 0) << "Empty rule list";
+  if (strategy == "traversal") {
+    return std::make_unique<TraversalRuleSampler>(potential_rules, default_remove_policy);
+  } else if (strategy == "probabilistic") {
+    return std::make_unique<ProbabilisticRuleSampler>(potential_rules, default_remove_policy, rand_seed, weights);
+  }
+
+  LOG(FATAL) << "Unimplemented strategy:" << strategy;
+  return nullptr;
+}
+
+AutoGenRule* TraversalRuleSampler::NextRule(bool remove) {
+  if (cur_idx_ < potential_rules_->size()) {
+    AutoGenRule* rule = potential_rules_->at(cur_idx_);
+    if (remove) {
+      ++cur_idx_;
+    }
+    return rule;
+  }
+
+  return nullptr;
+}
+
+ProbabilisticRuleSampler::ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
+                                                   bool default_remove_policy,
+                                                   utils::LinearRandomEngine::StateType rand_seed,
+                                                   const std::vector<int>& weights)
+    : RuleSampler(potential_rules, default_remove_policy),
+      weights_(weights),
+      rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
+  if (weights.empty()) {
+    weights_.resize(potential_rules.size(), 1);
+  } else {
+    CHECK_EQ(potential_rules.size(), weights_.size());
+  }
+  remains_ = potential_rules.size();
+}
+
+AutoGenRule* ProbabilisticRuleSampler::NextRule(bool remove) {
+  if (remains_ == 0) {
+    return nullptr;
+  }
+  int rule_idx = utils::SampleDiscreteFromDistribution(weights_, &rand_seed_);
+  if (remove) {
+    weights_[rule_idx] = 0;
+    --remains_;
+  }
+
+  return potential_rules_->at(rule_idx);
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
\ No newline at end of file
diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler.h b/paddle/cinn/auto_schedule/search_space/rule_sampler.h
new file mode 100644
index 0000000000000..828e4a775eeb1
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/rule_sampler.h
@@ -0,0 +1,114 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/utils/random_engine.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+class SearchState;
+
+// Select the next potential rule for the SearchState during the search process.
+class RuleSampler {
+ public:
+  /**
+   * @brief Create a RuleSampler with the specific strategy name and necessary construct parameters.
+   * @param potential_rules All possible rules to be sampled.
+   * @param default_remove_policy The default option to determine whether to delete the next rule after selecting it.
+   * @param strategy The rule sampling strategy.
+   *        Currently, the available strategies are "traversal" and "probabilistic",
+   *        where "traversal" means to select rules one by one until all rules are traversed,
+   *        and "probabilistic" means randomly picking rules according to the given distribution.
+   * @param rand_seed Random seed used by the probabilistic policy.
+   * @param weights Used for the probabilistic policy, giving each candidate a weight.
+   */
+  static std::unique_ptr<RuleSampler> Make(const std::vector<AutoGenRule*>& potential_rules,
+                                           bool default_remove_policy = true,
+                                           const std::string& strategy = "traversal",
+                                           utils::LinearRandomEngine::StateType rand_seed = 0,
+                                           const std::vector<int>& weights = {});
+
+  // Return the name of the sampling strategy
+  virtual const char* Name() const = 0;
+
+  // Reset associated states to sample from the beginning
+  virtual void Reset() = 0;
+
+  // Select a rule with the default remove policy.
+  AutoGenRule* NextRule() { return NextRule(default_remove_policy_); }
+
+ protected:
+  // A RuleSampler object should be created with the static function Make()
+  RuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy)
+      : potential_rules_(&potential_rules), default_remove_policy_(default_remove_policy) {}
+
+  // Select a rule to apply.
+  // The param remove determines whether to delete the next rule after selecting it;
+  // if remove == true, it will not be sampled in the future.
+  virtual AutoGenRule* NextRule(bool remove) = 0;
+
+  // The pointer refers to all potential rules
+  const std::vector<AutoGenRule*>* potential_rules_;
+
+  // The default policy to determine whether to delete the next rule after selecting it.
+  bool default_remove_policy_;
+};
+
+// Sample rules with the traversal strategy,
+// which means to select rules one by one until all rules are traversed.
+class TraversalRuleSampler : public RuleSampler {
+ public:
+  TraversalRuleSampler(const std::vector<AutoGenRule*>& potential_rules, bool default_remove_policy)
+      : RuleSampler(potential_rules, default_remove_policy), cur_idx_(0) {}
+
+  const char* Name() const override { return "traversal"; }
+
+  void Reset() override { cur_idx_ = 0; }
+
+ private:
+  AutoGenRule* NextRule(bool remove) override;
+
+ private:
+  int cur_idx_;
+};
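+
+// A hypothetical call site for weighted sampling (illustrative only; `rules`
+// stands for any std::vector<AutoGenRule*>, as built in rule_sampler_test.cc):
+//
+//   auto sampler = RuleSampler::Make(rules, /*default_remove_policy=*/false,
+//                                    "probabilistic", /*rand_seed=*/0, /*weights=*/{4, 1});
+//   AutoGenRule* rule = sampler->NextRule();  // rules[0] is drawn ~4x as often as rules[1]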
+
+// Sample rules with the probabilistic strategy,
+// which means randomly picking rules according to the given distribution.
+class ProbabilisticRuleSampler : public RuleSampler {
+ public:
+  ProbabilisticRuleSampler(const std::vector<AutoGenRule*>& potential_rules,
+                           bool default_remove_policy,
+                           utils::LinearRandomEngine::StateType rand_seed = 0,
+                           const std::vector<int>& weights = {});
+
+  const char* Name() const override { return "probabilistic"; }
+
+  void Reset() override {}
+
+ private:
+  AutoGenRule* NextRule(bool remove) override;
+
+ private:
+  std::vector<int> weights_;
+  utils::LinearRandomEngine::StateType rand_seed_;
+  int remains_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc b/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
new file mode 100644
index 0000000000000..91ca4fd5926b0
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/rule_sampler_test.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/rule_sampler.h"
+
+#include <gtest/gtest.h>
+
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+#ifdef CINN_WITH_CUDA
+Target target = common::DefaultNVGPUTarget();
+#else
+Target target = common::DefaultHostTarget();
+#endif
+
+std::vector<AutoGenRule*> GenerateTestRules() { return {new AutoUnroll(target), new SkipRule(target)}; }
+
+TEST(RuleSampler, Make) {
+  std::vector<AutoGenRule*> rules = GenerateTestRules();
+  auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal");
+  ASSERT_STREQ(traversal_rule_sampler->Name(), "traversal");
+  auto probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic");
+  ASSERT_STREQ(probabilistic_rule_sampler->Name(), "probabilistic");
+}
+
+TEST(TraversalRuleSampler, NextRule) {
+  std::vector<AutoGenRule*> rules = GenerateTestRules();
+  auto traversal_rule_sampler = RuleSampler::Make(rules, true, "traversal");
+  AutoGenRule* rule = traversal_rule_sampler->NextRule();
+  ASSERT_EQ("AutoUnroll", rule->GetRuleName());
+  rule = traversal_rule_sampler->NextRule();
+  ASSERT_EQ("SkipRule", rule->GetRuleName());
+  traversal_rule_sampler->Reset();
+  rule = traversal_rule_sampler->NextRule();
+  ASSERT_EQ("AutoUnroll", rule->GetRuleName());
+
+  traversal_rule_sampler = RuleSampler::Make(rules, false, "traversal");
+  rule = traversal_rule_sampler->NextRule();
+  ASSERT_EQ("AutoUnroll", rule->GetRuleName());
+  rule = traversal_rule_sampler->NextRule();
+  ASSERT_EQ("AutoUnroll", rule->GetRuleName());
+}
+
+TEST(ProbabilisticRuleSampler, NextRule) {
+  std::vector<AutoGenRule*> rules = GenerateTestRules();
+  auto probabilistic_rule_sampler = RuleSampler::Make(rules, false, "probabilistic", 0, {4, 1});
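+  // With weights {4, 1}, AutoUnroll is expected to be drawn roughly four times
+  // as often as SkipRule over the 20 samples below (the exact sequence is
+  // deterministic for a fixed seed).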
+  AutoGenRule* rule;
+  for (int i = 0; i < 20; ++i) {
+    rule = probabilistic_rule_sampler->NextRule();
+    VLOG(6) << "next rule name: " << rule->GetRuleName();
+  }
+
+  probabilistic_rule_sampler = RuleSampler::Make(rules, true, "probabilistic", 0, {4, 1});
+  probabilistic_rule_sampler->NextRule();
+  probabilistic_rule_sampler->NextRule();
+  ASSERT_EQ(nullptr, probabilistic_rule_sampler->NextRule());
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_space.cc b/paddle/cinn/auto_schedule/search_space/search_space.cc
new file mode 100644
index 0000000000000..af10da2215100
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_space.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/search_space.h"
+
+#include <glog/logging.h>
+
+#include <algorithm>
+#include <list>
+#include <map>
+
+#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
+#include "cinn/auto_schedule/search_space/block_sampler.h"
+#include "cinn/auto_schedule/search_space/rule_sampler.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/runtime/flags.h"
+
+DECLARE_bool(auto_schedule_use_cost_model);
+
+namespace cinn {
+namespace auto_schedule {
+
+SearchSpace::SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed)
+    : tune_task_(tune_task), rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {
+  const auto& target = tune_task_.target;
+  // initialize a set of rules that are commonly used by all states
+  // TODO(zhhsplendid): pass correct output names to AutoInline
+  // sketch_rules_.emplace_back(new AutoInline(target, tune_task_.output_names));
+  sketch_rules_.emplace_back(new MultiLevelTiling(target, MultiLevelTiling::kConfigs.at(target.arch)));
+  sketch_rules_.emplace_back(new AutoUnroll(target));
+  sketch_rules_.emplace_back(new SkipRule(target));
+}
+
+SearchState SearchSpace::GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model) {
+  bool has_manual_schedule = false;
+  if (has_manual_schedule) {
+    SearchState ret = ManualScheduleMutate(state);
+    return ret;
+  }
+  SearchState ret = RandomScheduleMutate(state);
+  if (FLAGS_auto_schedule_use_cost_model) {
+    ret->predicted_cost = cost_model.Predict(ret->ir_schedule.GetModule(), tune_task_.target);
+  }
+  VLOG(4) << JoinStatesDebugString("SearchSpace::GetScheduleMutate", {state}, /*verbose=*/VLOG_IS_ON(5));
+  return ret;
+}
+
+SearchState SearchSpace::ManualScheduleMutate(const SearchState& state) {
+  // TODO(zhhsplendid): Add manual schedule mutate
+  return state;
+}
+
+SearchState SearchSpace::RandomScheduleMutate(const SearchState& state) {
+  // 1. Find the schedules which can be applied to this Expr
+  // 2. Make a distribution over those schedules
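+  //
+  // A worked example with hypothetical numbers: if three rules are applicable
+  // and offer 2, 1 and 3 applicable positions respectively, the map below
+  // becomes {0 -> rule0, 2 -> rule1, 3 -> rule2} with cur_weight == 6;
+  // sampling 4 from [0, 6) then selects rule2 with inner index 4 - 3 = 1.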
+  std::map<int, int> weight_to_rule_index;
+  int cur_weight = 0;
+  SearchState ret(state);
+  std::vector<RuleApplyType> apply_types(ret->applicable_rules.size());
+  for (int idx = 0; idx != ret->applicable_rules.size(); ++idx) {
+    AutoGenRule* rule = ret->applicable_rules.at(idx);
+    RuleApplyType apply_type = rule->Init(&ret->ir_schedule);
+    VLOG(6) << "Evaluate rule:" << rule->GetRuleName() << "=" << static_cast<int>(apply_type);
+    apply_types[idx] = apply_type;
+    if (apply_type != RuleApplyType::kCannotApply) {
+      weight_to_rule_index[cur_weight] = idx;
+      cur_weight += rule->NumberApplicable();
+    }
+  }
+
+  if (weight_to_rule_index.empty()) {
+    // No applicable rule, return the input mod_expr
+    VLOG(6) << "No applicable rule";
+    return ret;
+  }
+
+  // 3. Sample a schedule from the distribution
+  int sample_weighted_index = utils::SampleUniformInt(0, cur_weight, &rand_seed_);
+
+  auto iter = weight_to_rule_index.upper_bound(sample_weighted_index);
+  --iter;
+
+  int sample_rule_index = iter->second;
+  CHECK_LT(sample_rule_index, ret->applicable_rules.size());
+  AutoGenRule* sample_rule = ret->applicable_rules.at(sample_rule_index);
+  VLOG(7) << "Apply rule: " << sample_rule->GetRuleName() << " with index=" << sample_weighted_index - iter->first;
+  // 4. Apply the schedule change
+  sample_rule->Apply(sample_weighted_index - iter->first);
+
+  // 5. Remove the rule after applying it
+  if (apply_types.at(sample_rule_index) != RuleApplyType::kCannotApply) {
+    ret->applicable_rules.erase(ret->applicable_rules.begin() + sample_rule_index);
+  }
+
+  return ret;
+}
+
+std::vector<SearchState> SearchSpace::InitSketchWithRandomStrategy(int num) {
+  VLOG(5) << "SearchSpace::InitSketchWithRandomStrategy with num=" << num;
+  ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
+                               utils::ForkRandomState(&rand_seed_));
+  std::vector<AutoGenRule*> init_rules;
+  std::transform(sketch_rules_.begin(), sketch_rules_.end(), std::back_inserter(init_rules), [](const auto& rule) {
+    return rule.get();
+  });
+  std::vector<SearchState> result;
+  while (result.size() < num) {
+    SearchState state(init_schedule, SearchState::NOT_INIT_COST, init_rules);
+    for (int i = 0; i < init_sketch_random_depth_; ++i) {
+      VLOG(6) << "Generating random sketch with RandomScheduleMutate at depth: " << i;
+      state = RandomScheduleMutate(state);
+      if (state->applicable_rules.empty()) {
+        break;
+      }
+    }
+
+    VLOG(5) << JoinStatesDebugString(
+        "SearchSpace::InitSketchWithRandomStrategy-New_Sketch", {state}, /*verbose=*/VLOG_IS_ON(6));
+    result.emplace_back(std::move(state));
+  }
+  return result;
+}
+
+std::vector<SearchState> SearchSpace::InitSketchWithRandomPrunedStrategy() {
+  VLOG(5) << "SearchSpace::InitSketchWithRandomPrunedStrategy";
+  ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
+                               utils::ForkRandomState(&rand_seed_));
+  auto all_blocks = init_schedule.GetAllBlocks();
+  auto block_sampler = BlockSampler::Make(all_blocks, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
+
+  std::vector<AutoGenRule*> init_rules;
+  std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) {
+    return rule.get();
+  });
+  CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
+
+  SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
+  std::vector<SearchState> states_buf1{init_state}, states_buf2;
+  std::vector<SearchState>* p_states_cur = &states_buf1;
+  std::vector<SearchState>* p_states_next = &states_buf2;
+  int total_steps = 0, steps;
+  std::string block_name;
+  while ("" != (block_name = block_sampler->NextBlock()) && total_steps < init_sketch_random_depth_) {
+    steps = utils::SampleUniformInt(1, init_rules.size() + 1, &rand_seed_);
+    if (total_steps + steps > init_sketch_random_depth_) {
+      steps = init_sketch_random_depth_ - total_steps;
+    }
+    total_steps += steps;
+    p_states_next->clear();
+    for (const auto& state : *p_states_cur) {
+      auto rule_sampler = RuleSampler::Make(init_rules, true, "probabilistic", utils::ForkRandomState(&rand_seed_));
+      auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), steps, false, 1);
+      p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end());
+    }
+    std::swap(p_states_cur, p_states_next);
+  }
+  VLOG(5) << JoinStatesDebugString(
+      "SearchSpace::InitSketchWithRandomPrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6));
+  return *p_states_cur;
+}
+
+std::vector<SearchState> SearchSpace::InitSketchWithRulePrunedStrategy() {
+  VLOG(5) << "SearchSpace::InitSketchWithRulePrunedStrategy";
+  ir::IRSchedule init_schedule(ir::ModuleExpr(tune_task_.GetLoweredFuncBodyExprs()),
+                               utils::ForkRandomState(&rand_seed_));
+  auto all_blocks = init_schedule.GetAllBlocks();
+  std::reverse(all_blocks.begin(), all_blocks.end());
+  auto block_sampler = BlockSampler::Make(all_blocks, true, "traversal");
+
+  std::vector<AutoGenRule*> init_rules;
+  std::transform(sketch_rules_.begin(), sketch_rules_.end() - 1, std::back_inserter(init_rules), [](const auto& rule) {
+    return rule.get();
+  });
+  CHECK(init_rules.size() > 0) << "number of init rules cannot be 0";
+
+  SearchState init_state(init_schedule, SearchState::NOT_INIT_COST, {});
+  std::vector<SearchState> states_buf1{init_state}, states_buf2;
+  std::vector<SearchState>* p_states_cur = &states_buf1;
+  std::vector<SearchState>* p_states_next = &states_buf2;
+  std::string block_name;
+  while ("" != (block_name = block_sampler->NextBlock())) {
+    p_states_next->clear();
+    for (const auto& state : *p_states_cur) {
+      auto rule_sampler = RuleSampler::Make(init_rules, true, "traversal");
+      auto new_states = ApplySketchRule(state, block_name, rule_sampler.get(), 0, true);
+      p_states_next->insert(p_states_next->end(), new_states.begin(), new_states.end());
+    }
+    std::swap(p_states_cur, p_states_next);
+  }
+  VLOG(5) << JoinStatesDebugString(
+      "SearchSpace::InitSketchWithRulePrunedStrategy", *p_states_cur, /*verbose=*/VLOG_IS_ON(6));
+  return *p_states_cur;
+}
+
+std::vector<SearchState> SearchSpace::GenerateSketches(int num, const std::string& strategy) {
+  VLOG(4) << "SearchSpace::GenerateSketches with num = " << num;
+
+  if (strategy == "random") {
+    return InitSketchWithRandomStrategy(num);
+  }
+
+  std::vector<SearchState> result;
+  while (result.size() < num) {
+    std::vector<SearchState> sketches;
+    if (strategy == "rule_prune") {
+      sketches = InitSketchWithRulePrunedStrategy();
+    } else if (strategy == "random_prune") {
+      sketches = InitSketchWithRandomPrunedStrategy();
+    } else {
+      LOG(FATAL) << "Unimplemented init sketch strategy";
+    }
+
+    // The more rules that have been applied to a sketch, the more likely it is
+    // to perform well; such sketches are saved toward the back of the queue,
+    // so we give priority to the results at the rear.
+    for (auto iter = sketches.rbegin(); iter != sketches.rend(); ++iter) {
+      result.push_back(*iter);
+      if (result.size() == num) {
+        break;
+      }
+    }
+  }
+  VLOG(4) << JoinStatesDebugString("SearchSpace::GenerateSketches", result, /*verbose=*/VLOG_IS_ON(5));
+  return result;
+}
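+
+// To illustrate the branching performed below (a hypothetical run): starting
+// from one state S on block "B" with sampled rules r1 then r2, step 1 may
+// yield {S, S+r1} and step 2 {S, S+r2, S+r1, S+r1+r2}; the state count grows
+// per applicable rule unless a branch is pruned by rule or by probability.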
+
+std::vector<SearchState> SearchSpace::ApplySketchRule(const SearchState& state,
+                                                      const std::string& block_name,
+                                                      RuleSampler* rule_sampler,
+                                                      int steps,
+                                                      bool prune_by_rule,
+                                                      double prune_probability) {
+  std::list<SearchState> layer{state};
+  int step = 0;
+  AutoGenRule* rule;
+  // After determining a SearchState and a block, each rule has two possibilities: apply and not apply.
+  // Across all transfer spaces, select a rule at each step, and collect all possible new states reached
+  // by applying and by not applying it. This forms a tree, and we can use rule pruning or random pruning
+  // to reduce the number of sketches.
+  VLOG(6) << "Collect the states of all transfers within steps: " << steps;
+  while ((step++ < steps || steps == 0) && (rule = rule_sampler->NextRule())) {
+    VLOG(7) << "step = " << step << ", rule: " << rule->GetRuleName();
+    std::list<SearchState> new_states;
+    int id = 0;
+    for (std::list<SearchState>::iterator iter = layer.begin(); iter != layer.end();) {
+      // Some rules will reduce the number of blocks, such as AutoInline,
+      // so we need to check whether the SearchState still has the block.
+      if (!(*iter)->ir_schedule.HasBlock(block_name)) {
+        ++iter;
+        continue;
+      }
+      auto type = rule->AnalyseApplyType(*iter, block_name);
+      VLOG(7) << "At SearchState " << ++id
+              << ", apply type = " << static_cast<std::underlying_type<RuleApplyType>::type>(type);
+      // if the rule cannot be applied, skip it
+      if (type == RuleApplyType::kCannotApply) {
+        ++iter;
+        continue;
+      }
+      // if the rule can be applied, apply it and determine whether to prune the branch that does not apply it
+      std::vector<SearchState> tmp_states = rule->ApplyOnBlock(*iter, block_name);
+      new_states.insert(new_states.end(), tmp_states.begin(), tmp_states.end());
+      bool need_prune = false;
+      if (prune_by_rule) {
+        need_prune = (type == RuleApplyType::kApplyAndPruneOtherRules);
+      } else {
+        need_prune = (utils::SampleUniformDouble(0, 1, &rand_seed_) < prune_probability);
+      }
+      if (need_prune) {
+        iter = layer.erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+    VLOG(7) << "apply on block: " << block_name << ", generate " << new_states.size() << " new states at step " << step;
+    layer.splice(layer.end(), std::move(new_states));
+  }
+  VLOG(6) << "apply on block: " << block_name << ", generate " << layer.size() - 1 << " more states in total";
+  return std::vector<SearchState>(layer.begin(), layer.end());
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_space.h b/paddle/cinn/auto_schedule/search_space/search_space.h
new file mode 100644
index 0000000000000..afa87174ca2c9
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_space.h
@@ -0,0 +1,104 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
+#include "cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
+#include "cinn/auto_schedule/search_space/rule_sampler.h"
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * This class is an abstraction of the transformations that can be applied to
+ * ir::Expr during auto-tuning. The transformation can be:
+ *
+ * 1. a manually defined schedule
+ * 2. a schedule generated by AutoGenRule
+ *
+ * TODO(zhhsplendid): de-duplicate the generated ModuleExprs
+ */
+class SearchSpace {
+ public:
+  SearchSpace(const TuneTask& tune_task, utils::LinearRandomEngine::StateType rand_seed = -1);
+
+  // Sketch mutate, returns the mutated ModuleExpr and its estimated cost
+  virtual SearchState GetScheduleMutate(const SearchState& state, const ExprCostModel& cost_model);
+
+  /**
+   * \brief Generate sketches as the initial population of evolutionary search.
+   * @param num The number of sketches to generate.
+   * @param strategy The strategy to generate sketches.
+   *        Current optional strategies are "rule_prune", "random_prune" and "random".
+   *        - "rule_prune": uses rules to prune and generates sketches as efficiently as possible.
+   *        - "random_prune": uses the new interface ApplySketchRule() to simulate the random generation of sketches,
+   *          and supports a rule returning multiple SearchStates as well as random pruning by probability.
+   *        - "random": randomly selects a block and a rule to apply and repeats this step several times;
+   *          however, each rule can be used on one SearchState at most once.
+   * @return Generated sketches.
+   */
+  virtual std::vector<SearchState> GenerateSketches(int num, const std::string& strategy);
+
+ private:
+  // TODO(zhhsplendid): mutate by manual schedule.
+  SearchState ManualScheduleMutate(const SearchState& state);
+
+  // mutate by sketch rules randomly
+  SearchState RandomScheduleMutate(const SearchState& state);
+
+  // Generate num sketches, each with several rounds of SketchMutate
+  std::vector<SearchState> InitSketchWithRandomStrategy(int num);
+
+  // Generate sketches pruned randomly as the initial population of evolutionary search
+  std::vector<SearchState> InitSketchWithRandomPrunedStrategy();
+
+  // Generate sketches pruned by rules as the initial population of evolutionary search
+  std::vector<SearchState> InitSketchWithRulePrunedStrategy();
+
+  /**
+   * @brief Collect the new states that may be transferred to after applying several rules on a block from a certain
+   * state.
+   * @param state Starting point of the state transition.
+   * @param block_name Name of the block to apply the rules to.
+   * @param rule_sampler Sampler that samples the next rule to apply on the block.
+   * @param steps Number of steps to apply rules.
+   * @param prune_by_rule If true, prune the state transition tree by rule, otherwise prune randomly.
+   * @param prune_probability Pruning probability of random pruning.
+   */
+  std::vector<SearchState> ApplySketchRule(const SearchState& state,
+                                           const std::string& block_name,
+                                           RuleSampler* rule_sampler,
+                                           int steps,
+                                           bool prune_by_rule,
+                                           double prune_probability = 1);
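+
+  // For illustration, one rule-pruned expansion step over a block named "B"
+  // (hypothetical arguments, cf. InitSketchWithRulePrunedStrategy):
+  //   auto states = ApplySketchRule(init_state, "B", rule_sampler.get(),
+  //                                 /*steps=*/0, /*prune_by_rule=*/true);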
+
+ private:
+  const TuneTask& tune_task_;
+  int init_sketch_random_depth_ = 6;
+  // supported AutoGenRules, every task holds a set
+  std::vector<std::unique_ptr<AutoGenRule>> sketch_rules_;
+  utils::LinearRandomEngine::StateType rand_seed_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_space_test.cc b/paddle/cinn/auto_schedule/search_space/search_space_test.cc
new file mode 100644
index 0000000000000..2e1064ba7f929
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_space_test.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/search_space.h"
+
+#include <gtest/gtest.h>
+
+namespace cinn {
+namespace auto_schedule {}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_state.cc b/paddle/cinn/auto_schedule/search_space/search_state.cc
new file mode 100644
index 0000000000000..48f9e8532085f
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_state.cc
@@ -0,0 +1,152 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/search_state.h"
+
+#include <memory>
+#include <sstream>
+#include <utility>
+#include <vector>
+
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/utils/functional.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+SearchState::SearchState(ir::IRSchedule ir_sch, float cost, const std::vector<AutoGenRule*>& rules)
+    : common::Shared<_SearchState_>(common::make_shared<_SearchState_>()) {
+  auto* state = get();
+  state->ir_schedule = std::move(ir_sch);
+  state->applicable_rules = rules;
+  state->predicted_cost = cost;
+}
+
+SearchState SearchState::Copy() const { return SearchState((*this)->ir_schedule, (*this)->predicted_cost, {}); }
+
+std::string _SearchState_::DebugString() const {
+  const auto& exprs = ir_schedule.GetModule().GetExprs();
+  std::stringstream module_stream;
+  for (auto i = 0; i < exprs.size(); ++i) {
+    module_stream << "Expr " << i << " {\n" << exprs.at(i) << "\n} // end Expr";
+  }
+
+  const char* fmt_str = R"ROC(
+ModuleExpr {
+%s
+} // end ModuleExpr
+ScheduleDesc {
+%s
+} // end ScheduleDesc
+predicted_cost: %f)ROC";
+
+  return utils::StringFormat(
+      fmt_str, module_stream.str().c_str(), ir_schedule.GetTraceDesc().DebugString().c_str(), predicted_cost);
+}
+
+bool operator<(const SearchState& left, const SearchState& right) {
+  return left->predicted_cost < right->predicted_cost;
+}
+
+// Visit every node by expanding all of their fields in dfs order
+class DfsWithExprsFields : public ir::IRVisitor {
+ protected:
+#define __m(t__)                          \
+  void Visit(const ir::t__* x) override { \
+    for (auto* n : x->expr_fields()) {    \
+      if (n->defined()) {                 \
+        Visit(n);                         \
+      }                                   \
+    }                                     \
+  }
+
+  NODETY_FORALL(__m)
+#undef __m
+
+  void Visit(const Expr* expr) override { IRVisitor::Visit(expr); }
+};
+
+// Generate a reduced hash of an AST tree by combining the hashes of all AST nodes
+class IrNodesStructuralHash : public DfsWithExprsFields {
+ public:
+  IrNodesStructuralHash(size_t init_key) : hash_key_(init_key) {}
+  size_t operator()(const Expr* expr) {
+    Visit(expr);
+    return hash_key_;
+  }
+
+  void Visit(const Expr* expr) override {
+    static decltype(ir::kIrNodeTyReprs) Node2Name = ir::kIrNodeTyReprs;
+    if (!expr->defined()) return;
+    auto type_code = static_cast<IrNodeTyUnderlyingType>(expr->node_type());
+    hash_key_ = utils::HashCombine(hash_key_, type_code);
+    DfsWithExprsFields::Visit(expr);
+  }
+
+ private:
+  void Visit(const ir::_Tensor_* x) override {
+    for (auto& e : x->shape) {
+      Visit(&e);
+    }
+    DfsWithExprsFields::Visit(x->buffer.As<ir::_Buffer_>());
+  }
+
+  using IrNodeTyUnderlyingType = std::underlying_type<ir::IrNodeTy>::type;
+  size_t hash_key_;
+};
+
+size_t SearchStateHash::operator()(const SearchState& s) const {
+  size_t hash_key = 0;
+  const auto& exprs = s->ir_schedule.GetModule().GetExprs();
+  for (auto&& expr : exprs) {
+    hash_key = IrNodesStructuralHash(hash_key)(&expr);
+  }
+  return hash_key;
+}
+
+bool SearchStateEqual::operator()(const SearchState& lhs, const SearchState& rhs) const {
+  const auto& lhs_exprs = lhs->ir_schedule.GetModule().GetExprs();
+  const auto& rhs_exprs = rhs->ir_schedule.GetModule().GetExprs();
+  // compare the sizes of the exprs first
+  if (lhs_exprs.size() != rhs_exprs.size()) return false;
+
+  // compare every expr one by one with ir::IrEqualVisitor
+  for (int i = 0; i < lhs_exprs.size(); ++i) {
+    ir::IrEqualVisitor comparator(/*allow_name_suffix_diff=*/true);  // ignore suffix differences in names
+    if (!comparator.Compare(lhs_exprs[i], rhs_exprs[i])) return false;
+  }
+  return true;
+}
+
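+// The two functors above are meant to be used together, e.g. (as in
+// EvolutionarySearch's visited_candidates_ member):
+//   std::unordered_set<SearchState, SearchStateHash, SearchStateEqual> seen;
+//   seen.insert(state);  // structural de-duplication of schedules
+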
+std::string JoinStatesDebugString(const std::string& title, const std::vector<SearchState>& states, bool verbose) {
+  std::stringstream ss;
+  ss << title << " states size:" << states.size() << "\n";
+  SearchStateHash state_hasher;
+  for (size_t i = 0; i < states.size(); ++i) {
+    uint64_t hash_key = state_hasher(states[i]);
+    if (verbose) {
+      ss << "\tState-" << i << " hash:" << hash_key << "\t content:------>" << states[i]->DebugString() << "\n<------";
+    } else {
+      ss << "\tState-" << i << " hash:" << hash_key << "\n";
+    }
+  }
+  return std::move(*ss.rdbuf()).str();
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_state.h b/paddle/cinn/auto_schedule/search_space/search_state.h
new file mode 100644
index 0000000000000..db2bfa3f7e276
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_state.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "cinn/common/object.h"
+#include "cinn/common/shared.h"
+#include "cinn/ir/ir_compare.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/ir_visitor.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+struct _SearchState_;
+class AutoGenRule;
+
+//! Shared wrapper for _SearchState_
+class SearchState : public common::Shared<_SearchState_> {
+ public:
+  SearchState() = default;
+  // create a new SearchState
+  explicit SearchState(ir::IRSchedule ir_sch,
+                       float cost = NOT_INIT_COST,
+                       const std::vector<AutoGenRule*>& rules = {});
+
+  // Constant standing for a cost that has not been initialized
+  static constexpr float NOT_INIT_COST = std::numeric_limits<float>::max();
+  // compare function for two states
+  friend bool operator<(const SearchState& left, const SearchState& right);
+
+  // Deep copy a SearchState
+  SearchState Copy() const;
+};
+
+//! Class to store intermediate states during the search
+struct _SearchState_ : public common::Object {
+  // IRSchedule contains the ir::ModuleExpr and the trace of the scheduling process
+  ir::IRSchedule ir_schedule;
+  // Cost predicted by the cost model
+  float predicted_cost;
+  // The rules that can be applied to the IRSchedule at this state.
+  std::vector<AutoGenRule*> applicable_rules;
+
+  // Return a detailed string of the content for debugging
+  std::string DebugString() const;
+
+  const char* type_info() const override { return __type_info__; }
+  static constexpr const char* __type_info__ = "auto_schedule_state";
+};
+
+// SearchStateHash hash functor that visits every AST node and combines the hashes of their node_type in dfs order
+struct SearchStateHash {
+  size_t operator()(const SearchState& s) const;
+};
+
+// SearchState equality functor, using ir::IrEqualVisitor to compare the AST structure and fields
+struct SearchStateEqual {
+  bool operator()(const SearchState& lhs, const SearchState& rhs) const;
+};
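+
+// Note: since only node types feed the hash and IrEqualVisitor runs with
+// allow_name_suffix_diff=true, two module trees that differ only in name
+// suffixes (e.g. blocks named "i0" vs "i0_0") hash equal and compare equal.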
+
+/*!
+ * \brief Concatenate the debug strings of all states with additional info.
+ * \param title Head of the result string.
+ * \param states SearchState array to be debugged.
+ * \param verbose Whether to enable more verbose debug info.
+ * \return The concatenated debug string.
+ */
+std::string JoinStatesDebugString(const std::string& title,
+                                  const std::vector<SearchState>& states,
+                                  bool verbose = false);
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_space/search_state_test.cc b/paddle/cinn/auto_schedule/search_space/search_state_test.cc
new file mode 100644
index 0000000000000..598fc95317589
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_space/search_state_test.cc
@@ -0,0 +1,136 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_space/search_state.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "cinn/cinn.h"
+#include "cinn/common/context.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+TEST(TestSearchState, SearchStateHash_Equal) {
+  Target target = common::DefaultHostTarget();
+
+  ir::Expr M(32);
+  ir::Expr N(32);
+
+  lang::Placeholder<float> A("A", {M, N});
+  ir::Tensor B = lang::Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) + ir::Expr(2.f); }, "B");
+  ir::Tensor C = lang::Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
+
+  cinn::common::Context::Global().ResetNameId();
+  auto a_plus_const_funcs_1 =
+      lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
+
+  cinn::common::Context::Global().ResetNameId();
+  auto a_plus_const_funcs_2 =
+      lang::LowerVec("A_plus_const", poly::CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
+
+  cinn::common::Context::Global().ResetNameId();
+  auto a_plus_b_funcs = lang::LowerVec("A_plus_B", poly::CreateStages({A, C}), {A, C}, {}, {}, nullptr, target, true);
+
+  std::string a_plus_const_funcs_1_str = R"ROC(function A_plus_const (_A, _B)
+{
+  ScheduleBlock(root)
+  {
+    serial for (i, 0, 32)
+    {
+      serial for (j, 0, 32)
+      {
+        ScheduleBlock(B)
+        {
+          i0, i1 = axis.bind(i, j)
+          B[i0, i1] = (A[i0, i1] + 2.00000000f)
+        }
+      }
+    }
+  }
+})ROC";
+
+  std::string a_plus_const_funcs_2_str = R"ROC(function A_plus_const (_A, _B)
+{
+  ScheduleBlock(root)
+  {
+    serial for (i, 0, 32)
+    {
+      serial for (j, 0, 32)
+      {
+        ScheduleBlock(B)
+        {
+          i0, i1 = axis.bind(i, j)
+          B[i0, i1] = (A[i0, i1] + 2.00000000f)
+        }
+      }
+    }
+  }
+})ROC";
+
+  std::string a_plus_b_funcs_str = R"ROC(function A_plus_B (_A, _C)
+{
+  ScheduleBlock(root)
+  {
+    {
+      serial for (i, 0, 32)
+      {
+        serial for (j, 0, 32)
+        {
+          ScheduleBlock(B)
+          {
+            i0, i1 = axis.bind(i, j)
+            B[i0, i1] = (A[i0, i1] + 2.00000000f)
+          }
+        }
+      }
+      serial for (i, 0, 32)
+      {
+        serial for (j, 0, 32)
+        {
+          ScheduleBlock(C)
+          {
+            i0_0, i1_0 = axis.bind(i, j)
+            C[i0_0, i1_0] = (A[i0_0, i1_0] + B[i0_0, i1_0])
+          }
+        }
+      }
+    }
+  }
+})ROC";
+
+  ASSERT_EQ(a_plus_const_funcs_1.size(), 1);
+  EXPECT_EQ(a_plus_const_funcs_1_str,
utils::GetStreamCnt(a_plus_const_funcs_1.front())); + ASSERT_EQ(a_plus_const_funcs_2.size(), 1); + EXPECT_EQ(a_plus_const_funcs_2_str, utils::GetStreamCnt(a_plus_const_funcs_2.front())); + ASSERT_EQ(a_plus_b_funcs.size(), 1); + EXPECT_EQ(a_plus_b_funcs_str, utils::GetStreamCnt(a_plus_b_funcs.front())); + + SearchState a_plus_const_state1(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_1.front()->body}))); + SearchState a_plus_const_state2(ir::IRSchedule(ir::ModuleExpr({a_plus_const_funcs_2.front()->body}))); + SearchState a_plus_b_state(ir::IRSchedule(ir::ModuleExpr({a_plus_b_funcs.front()->body}))); + + SearchStateHash hash_functor; + SearchStateEqual equal_functor; + ASSERT_EQ(hash_functor(a_plus_const_state1), hash_functor(a_plus_const_state2)); + ASSERT_TRUE(equal_functor(a_plus_const_state1, a_plus_const_state2)); + ASSERT_NE(hash_functor(a_plus_const_state1), hash_functor(a_plus_b_state)); + ASSERT_FALSE(equal_functor(a_plus_const_state1, a_plus_b_state)); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt b/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt new file mode 100644 index 0000000000000..a31e01c801a57 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/CMakeLists.txt @@ -0,0 +1,7 @@ +add_subdirectory(mutate_rule) + +core_gather_headers() + +gather_srcs(cinnapi_src SRCS evolutionary_search.cc) + +cc_test(test_evolutionary_search SRCS evolutionary_search_test.cc DEPS cinncore test_program_builder) diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 0000000000000..c938718ad06af --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,302 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "cinn/auto_schedule/search_strategy/evolutionary_search.h"
+
+#include <glog/logging.h>
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "cinn/auto_schedule/database/database.h"
+#include "cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
+#include "cinn/auto_schedule/search_space/search_space.h"
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h"
+#include "cinn/auto_schedule/task/task_registry.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/tuning.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/utils/multi_threading.h"
+#include "cinn/utils/sized_multi_set.h"
+#include "cinn/utils/string.h"
+
+DECLARE_bool(auto_schedule_use_cost_model);
+
+namespace cinn {
+namespace auto_schedule {
+
+EvolutionarySearch::EvolutionarySearch(const TuneTask& tune_task,
+                                       const ExprCostModel& cost_model,
+                                       Database* database,
+                                       utils::LinearRandomEngine::StateType rand_seed,
+                                       const std::vector<std::tuple<std::string, double>>& mutate_rules)
+    : tune_task_(tune_task),
+      cost_model_(cost_model),
+      database_(database),
+      rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)),
+      mutators_(mutate_rules) {
+  search_space_ = std::make_unique<SearchSpace>(tune_task, utils::ForkRandomState(&rand_seed_));
+  if (mutators_.empty()) {
+    mutators_.push_back(std::make_tuple("mutate_tile_size", 1.0));
+  }
+  double accum_weight = 0.0;
+  for (const auto& mutator : mutators_) {
+    if (std::get<1>(mutator) > 0) {
+      accum_weight += std::get<1>(mutator);
+      weighted_mutators_.insert(std::make_pair(accum_weight, MutateRule::Make(std::get<0>(mutator))));
+    }
+  }
+
+  post_schedule_rules_.emplace_back(new CooperativeProcess);
+}
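+
+// Illustrative example of the accumulated-weight map built above: with
+// mutators_ = {("mutate_tile_size", 1.0)} it holds {1.0 -> MutateTileSize};
+// Mutate() then samples u ~ U[0, accum_weight) and picks upper_bound(u).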
+
+EvolutionarySearch::~EvolutionarySearch() {}
+
+SearchState EvolutionarySearch::SearchModuleExpr(const TuningOptions& options) {
+  return SearchModuleExprBests(options)[0];
+}
+
+std::vector<SearchState> EvolutionarySearch::SearchModuleExprBests(const TuningOptions& options) {
+  VLOG(4) << "start SearchModuleExprBests with initial statistics: visited_candidates size="
+          << visited_candidates_.size();
+  std::vector<SearchState> init_population;
+  std::vector<SearchState> topk_from_database = GetTopKCandidatesFromDatabase(options.evolution_pick_database_topk);
+  VLOG(4) << JoinStatesDebugString(
+      "EvolutionarySearch::GetTopKCandidatesFromDatabase", topk_from_database, /*verbose=*/VLOG_IS_ON(5));
+  int init_num = options.evolution_init_population_num - topk_from_database.size();
+
+  std::vector<SearchState> init_sketch = InitSketch(init_num, "rule_prune");
+  VLOG(4) << JoinStatesDebugString("EvolutionarySearch::InitSketch", init_sketch, /*verbose=*/VLOG_IS_ON(5));
+
+  init_population.insert(init_population.end(), topk_from_database.begin(), topk_from_database.end());
+  init_population.insert(init_population.end(), init_sketch.begin(), init_sketch.end());
+
+  std::vector<SearchState> picked_bests =
+      Evolve(init_population, options.evolution_cross_over_num, options.num_samples_per_iteration);
+  VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve", picked_bests, /*verbose=*/VLOG_IS_ON(5));
+  return picked_bests;
+}
+
+std::vector<SearchState> EvolutionarySearch::SearchModuleExprEpsGreedy(const TuningOptions& options) {
+  std::vector<SearchState> picked_bests = SearchModuleExprBests(options);
+  int random_num = options.evolution_init_population_num - options.evolution_pick_database_topk;
+  auto results = PickNextGenerationEpsGreedy(picked_bests,
+                                             InitSketch(random_num, "random_prune"),
+                                             options.num_samples_per_iteration,
+                                             options.evolution_eps_greedy);
+  VLOG(4) << JoinStatesDebugString(
+      "EvolutionarySearch::PickNextGenerationEpsGreedy", results, /*verbose=*/VLOG_IS_ON(5));
+  return results;
+}
+
+std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(int topk) {
+  std::vector<SearchState> results;
+  const auto& task_key = tune_task_.serialized_key;
+  auto records = database_->GetTopK(task_key, topk);
+  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
+  for (auto&& record : records) {
+    ir::IRSchedule ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr),
+                          utils::ForkRandomState(&rand_seed_));
+    ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
+    results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
+  }
+  return results;
+}
+
+void ApplyPostScheduleRules(ir::IRSchedule* schedule,
+                            const std::vector<std::unique_ptr<PostScheduleRule>>& post_schedule_rules) {
+  schedule->TagPostSchedule();
+  for (const auto& post_rule : post_schedule_rules) {
+    post_rule->Apply(schedule);
+  }
+}
+
+std::vector<SearchState> EvolutionarySearch::InitSketch(int num, const std::string& strategy) {
+  VLOG(4) << "InitSketch with num:" << num << ", strategy: " << strategy;
+  std::vector<SearchState> states = search_space_->GenerateSketches(num, strategy);
+  auto post_schedule_fn = [this, &states](int index) {
+    ApplyPostScheduleRules(&states[index]->ir_schedule, post_schedule_rules_);
+  };
+  utils::parallel_run(post_schedule_fn, utils::SequenceDispatcher(0, states.size()), states.size());
+
+  return states;
+}
+
+SearchState EvolutionarySearch::CrossOver(const SearchState& state1, const SearchState& state2) {
+  // TODO(CtfGo): tracing CrossOver with IRSchedule
+  std::vector<ir::Expr> cross_over_exprs;
+  std::vector<ir::Expr> father_exprs = state1->ir_schedule.GetModule().GetExprs();
+  std::vector<ir::Expr> mother_exprs = state2->ir_schedule.GetModule().GetExprs();
+
+  CHECK_EQ(father_exprs.size(), mother_exprs.size())
+      << "CrossOver ModuleExprs in EvolutionarySearch must have the same number of ASTs";
+
+  for (size_t i = 0; i < father_exprs.size(); ++i) {
+    if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
+      cross_over_exprs.push_back(optim::IRCopy(father_exprs[i]));
+    } else {
+      cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
+    }
+  }
+  auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs), utils::ForkRandomState(&rand_seed_)));
+  if (FLAGS_auto_schedule_use_cost_model) {
+    res->predicted_cost = cost_model_.Predict(res->ir_schedule.GetModule(), tune_task_.target);
+  }
+  VLOG(5) << JoinStatesDebugString("EvolutionarySearch::CrossOver", {state1, state2, res}, /*verbose=*/VLOG_IS_ON(6));
+  return res;
+}
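+
+// Example of the per-AST coin flips above (hypothetical draws): with three
+// ASTs and draws {0, 1, 0}, the child deep-copies ASTs 0 and 2 from state1
+// and AST 1 from state2.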
+
+SearchState EvolutionarySearch::Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed) {
+  CHECK_GT(weighted_mutators_.size(), 0) << "There is no mutate rule that can be applied.";
+  double accu_weight = (weighted_mutators_.rbegin())->first;
+  CHECK_GT(accu_weight, 0) << "The accumulated weight must be greater than 0.";
+  // sample a mutate rule
+  double sample_weight = utils::SampleUniformDouble(0, accu_weight, rand_seed);
+  auto sampled_iter = weighted_mutators_.upper_bound(sample_weight);
+  MutateRule* mutator = sampled_iter->second.get();
+  CHECK(mutator) << "mutator not defined";
+  // apply the mutation on the trace of the SearchState
+  auto trace = state->ir_schedule.GetTraceDesc();
+  auto new_trace = mutator->Apply(trace, rand_seed);
+  // replay the mutated trace on the original ModuleExpr to generate a new ir_schedule
+  const auto& task_key = tune_task_.serialized_key;
+  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
+  ir::IRSchedule new_ir_sch(optim::IRCopy(task_registry->Get(task_key)->module_expr),
+                            utils::ForkRandomState(rand_seed));
+  new_trace.Replay(&new_ir_sch, true);
+  ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
+  auto res = SearchState(std::move(new_ir_sch));
+
+  VLOG(5) << JoinStatesDebugString("EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
+  return res;
+}
+
+std::vector<SearchState> EvolutionarySearch::Evolve(const std::vector<SearchState>& population,
+                                                    int cross_over_num,
+                                                    int ret_num) {
+  VLOG(4) << utils::StringFormat(
+      "Evolve with population size=%lu,cross_over_num:%d,ret_num:%d", population.size(), cross_over_num, ret_num);
+  int generation_num = population.size();
+  if (generation_num == 0) {
+    return std::vector<SearchState>();
+  }
+  // init evolution
+  std::vector<SearchState> evolution(population);
+  for (SearchState& search_state : evolution) {
+    if (search_state->predicted_cost == SearchState::NOT_INIT_COST && FLAGS_auto_schedule_use_cost_model) {
+      search_state->predicted_cost = cost_model_.Predict(search_state->ir_schedule.GetModule(), tune_task_.target);
+    }
+  }
+  VLOG(4) << JoinStatesDebugString("EvolutionarySearch::Evolve: Init evolution:", evolution, /*verbose=*/VLOG_IS_ON(5));
+  // cross over
+  for (int i = 0; i < cross_over_num; ++i) {
+    int first_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
+    int second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
+    while (first_rand_idx == second_rand_idx) {
+      second_rand_idx = utils::SampleUniformInt(0, generation_num, &rand_seed_);
+    }
+    evolution.push_back(CrossOver(population[first_rand_idx], population[second_rand_idx]));
+  }
+  VLOG(4) << JoinStatesDebugString(
+      "EvolutionarySearch::Evolve: after CrossOver evolution:", evolution, /*verbose=*/VLOG_IS_ON(5));
+  // mutate
+  std::vector<SearchState> mutated_individuals(evolution.size());
+  std::vector<utils::LinearRandomEngine::StateType> rand_seeds(evolution.size());
+  for (int i = 0; i < rand_seeds.size(); ++i) {
+    rand_seeds[i] = utils::ForkRandomState(&rand_seed_);
+  }
+  auto mutate_fn = [this, &evolution, &mutated_individuals, &rand_seeds](int index) {
+    mutated_individuals[index] = Mutate(evolution[index], &rand_seeds[index]);
+  };
+  utils::parallel_run(mutate_fn, utils::SequenceDispatcher(0, evolution.size()), evolution.size());
+  if (FLAGS_auto_schedule_use_cost_model) {
+    for (size_t i = 0; i < mutated_individuals.size(); ++i) {
+      mutated_individuals[i]->predicted_cost =
+          cost_model_.Predict(mutated_individuals[i]->ir_schedule.GetModule(), tune_task_.target);
+    }
+  }
+  VLOG(4) << JoinStatesDebugString(
+      "EvolutionarySearch::Evolve: mutated individuals:", mutated_individuals, /*verbose=*/VLOG_IS_ON(5));
+  // select the top ret_num individuals by predicted cost
+  utils::SizedMultiSet<SearchState> evolution_with_cost(ret_num);
+  for (size_t i = 0; i < evolution.size(); ++i) {
+    evolution_with_cost.Push(evolution[i]);
+  }
+  for (size_t i = 0; i < mutated_individuals.size(); ++i) {
+    evolution_with_cost.Push(mutated_individuals[i]);
+  }
+  auto selected_individuals = evolution_with_cost.ReturnAsContainer<std::vector<SearchState>>();
+  VLOG(4) << JoinStatesDebugString(
+      "EvolutionarySearch::Evolve: selected individuals:", selected_individuals, /*verbose=*/VLOG_IS_ON(5));
+
+  return selected_individuals;
+}
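+
+// A quick arithmetic example for the eps-greedy split below (hypothetical
+// numbers): num = 10 and eps_greedy = 0.2 give num_rands = 2 and
+// num_bests = 8, i.e. eight cost-model picks plus two random sketches per
+// generation, before de-duplication against visited_candidates_.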
+
+std::vector<SearchState> EvolutionarySearch::PickNextGenerationEpsGreedy(const std::vector<SearchState>& picked_bests,
+                                                                         const std::vector<SearchState>& random_init,
+                                                                         int num,
+                                                                         float eps_greedy) {
+  int num_rands = num * eps_greedy;
+  int num_bests = num - num_rands;
+
+  std::vector<SearchState> result;
+  SearchState selected;
+  int deduplicated_cnt = 0;
+  int best_idx = 0;
+  int rand_idx = 0;
+  while (result.size() < num) {
+    if (result.size() < num_bests && best_idx < picked_bests.size()) {
+      selected = picked_bests[best_idx];
+      ++best_idx;
+    } else if (rand_idx < random_init.size()) {
+      selected = random_init[rand_idx];
+      ++rand_idx;
+    } else if (best_idx < picked_bests.size()) {
+      selected = picked_bests[best_idx];
+      ++best_idx;
+    } else {
+      break;
+    }
+
+    if (!visited_candidates_.count(selected)) {  // deduplicate
+      VLOG(4) << JoinStatesDebugString(
+          "EvolutionarySearch::PickNextGenerationEpsGreedy-Selected", {selected}, /*verbose=*/VLOG_IS_ON(5));
+      visited_candidates_.insert(selected);
+      result.push_back(selected);
+    } else {
+      ++deduplicated_cnt;
+      VLOG(4) << JoinStatesDebugString(
+          "EvolutionarySearch::PickNextGenerationEpsGreedy-Deduplicated", {selected}, /*verbose=*/VLOG_IS_ON(5));
+    }
+  }
+
+  VLOG(4) << utils::StringFormat(
+      "PickNextGenerationEpsGreedy: picked_bests size=%lu,random_init size=%lu,num=%d,"
+      "eps_greedy=%f,deduplicated_cnt=%d,result size=%lu",
+      picked_bests.size(),
+      random_init.size(),
+      num,
+      eps_greedy,
+      deduplicated_cnt,
+      result.size());
+  return result;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h
new file mode 100644
index 0000000000000..40e5bb9f7e889
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search.h
@@ -0,0 +1,146 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
+#include "cinn/auto_schedule/database/database.h"
+#include "cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
+#include "cinn/auto_schedule/search_space/search_space.h"
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/tuning.h"
+#include "cinn/ir/ir_schedule.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * Class implementing evolutionary search on the ModuleExpr search space.
+ */
+class EvolutionarySearch {
+ public:
+  /**
+   * constructor with TuneTask.
+   *
+   * @param tune_task: the TuneTask this class works on. This class doesn't
+   *     take ownership of the object.
+   */
+  EvolutionarySearch(const TuneTask& tune_task,
+                     const ExprCostModel& cost_model,
+                     Database* database,
+                     utils::LinearRandomEngine::StateType rand_seed = -1,
+                     const std::vector<std::tuple<std::string, double>>& mutate_rules = {});
+
+  /**
+   * Destructor
+   */
+  ~EvolutionarySearch();
+
+  /**
+   * Run the evolutionary search for one iteration.
+   *
+   * @return SearchState containing the best ir::ModuleExpr searched in this iteration
+   */
+  SearchState SearchModuleExpr(const TuningOptions& options);
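+
+  // A hypothetical call site (cf. evolutionary_search_test.cc):
+  //   EvolutionarySearch searcher(task, cost_model, &database);
+  //   SearchState best = searcher.SearchModuleExpr(options);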
+
+  /**
+   * Run the evolutionary search for one iteration.
+   *
+   * @return SearchState(s) containing the best ir::ModuleExpr(s) searched in this iteration
+   */
+  std::vector<SearchState> SearchModuleExprBests(const TuningOptions& options);
+
+  /**
+   * Run the evolutionary search for one iteration, but since evolutionary
+   * search with a cost model may not be accurate, this method picks
+   * "eps * total_return_size" random samples along with the best
+   * ir::ModuleExpr's searched in this iteration.
+   *
+   * @return SearchStates containing the best ir::ModuleExpr's searched
+   *     in this iteration and some random samples. There are
+   *     "eps * total_return_size" random samples and
+   *     "(1 - eps) * total_return_size" best searched samples.
+   */
+  std::vector<SearchState> SearchModuleExprEpsGreedy(const TuningOptions& options);
+
+#ifdef CINN_WITH_TEST
+  /**
+   * Method only called during testing. It is used to set a mock search
+   * space.
+   *
+   * @param search_space: the mock search space; note that EvolutionarySearch
+   *     takes the ownership.
+   */
+  void SetSearchSpace(SearchSpace* search_space) { search_space_.reset(search_space); }
+
+  // Method only called during testing; it is a wrapper of the private method InitSketch().
+  std::vector<SearchState> TestInitSketch(int num, const std::string& strategy) { return InitSketch(num, strategy); }
+
+  // Method only called during testing; it is a wrapper of the private method Evolve().
+  std::vector<SearchState> TestEvolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num) {
+    return Evolve(population, cross_over_num, ret_num);
+  }
+#endif
+
+ private:
+  std::vector<SearchState> GetTopKCandidatesFromDatabase(int topk);
+
+  /**
+   * \brief Generate sketches as the initial population of evolutionary search.
+   * @param num The number of sketches to generate.
+   * @param strategy The strategy to generate sketches.
+   *        Current optional strategies are "rule_prune", "random_prune" and "random".
+   *        - "rule_prune": uses rules to prune and generates sketches as efficiently as possible.
+   *        - "random_prune": uses the new interface ApplySketchRules() to simulate the random generation of sketches,
+   *          and supports a rule returning multiple SearchStates as well as random pruning by probability.
+   *        - "random": randomly selects a block and a rule to apply and repeats this step several times;
+   *          however, each rule can be used on one SearchState at most once.
+   * @return Generated sketches.
+   */
+  std::vector<SearchState> InitSketch(int num, const std::string& strategy);
+
+  SearchState Mutate(const SearchState& state, utils::LinearRandomEngine::StateType* rand_seed);
+
+  SearchState CrossOver(const SearchState& state1, const SearchState& state2);
+
+  std::vector<SearchState> Evolve(const std::vector<SearchState>& population, int cross_over_num, int ret_num);
+
+  std::vector<SearchState> PickNextGenerationEpsGreedy(const std::vector<SearchState>& population,
+                                                       const std::vector<SearchState>& random_init,
+                                                       int num,
+                                                       float eps_greedy);
+
+ private:
+  std::unique_ptr<SearchSpace> search_space_;
+  const TuneTask& tune_task_;
+  const ExprCostModel& cost_model_;  // not owned
+  Database* database_;               // not owned
+  // used to deduplicate states with the same structural IR
+  std::unordered_set<SearchState, SearchStateHash, SearchStateEqual> visited_candidates_;
+  // mutate rule names and their weights
+  std::vector<std::tuple<std::string, double>> mutators_;
+  // mutate rules, the key is the accumulated weight of each mutate rule
+  std::map<double, std::unique_ptr<MutateRule>> weighted_mutators_;
+  // schedule rules used after mutation
+  std::vector<std::unique_ptr<PostScheduleRule>> post_schedule_rules_;
+  utils::LinearRandomEngine::StateType rand_seed_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc
new file mode 100644
index 0000000000000..4f6764b41f65a
--- /dev/null
+++ b/paddle/cinn/auto_schedule/search_strategy/evolutionary_search_test.cc
@@ -0,0 +1,196 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/search_strategy/evolutionary_search.h"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <vector>
+
+#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
+#include "cinn/auto_schedule/database/database.h"
+#include "cinn/auto_schedule/search_space/search_space.h"
+#include "cinn/auto_schedule/search_space/search_state.h"
+#include "cinn/auto_schedule/task/task_creator.h"
+#include "cinn/auto_schedule/task/task_registry.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/tuning.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_schedule.h"
+#include "tests/program_builder.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+std::vector<TuneTask> CreateTasks(const frontend::Program& program, const Target& target) {
+  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
+  TaskCreator task_creator;
+  auto tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
+  const auto& dtype_dict = graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>("inferdtype");
+  const auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, hlir::framework::shape_t>>("infershape");
+  auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(dtype_dict, shape_dict, target);
+  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
+  for (auto i = 0; i < tasks.size(); ++i) {
+    tasks[i].Initialize(shape_dict, dtype_dict, op_lowerer.get());
+    task_registry->Regist(tasks[i].serialized_key, ir::ModuleExpr(tasks[i].GetLoweredFuncBodyExprs()));
+  }
+  return tasks;
+}
+
+/**
+ * A mock search space that is only used for testing. It creates integer ir::Expr from
+ * 0, -1, -2, ...
and set the cost value same as the integer value. + * + * So evolutionary search should be able to find the minimal ModuleExpr with + * smallest ir::Expr. This file tests it. + */ +class MockSearchSpace : public SearchSpace { + public: + MockSearchSpace(const TuneTask& tune_task) : SearchSpace(tune_task) {} + + int GetMinExprValue() const { return min_expr_value_; } + + int GetModuleExprSize() const { return module_expr_size_; } + + std::vector GenerateSketches(int num, const std::string& strategy) override { + std::vector ret; + for (int i = 0; i < num; ++i) { + std::vector exprs; + for (int j = 0; j < module_expr_size_; ++j) { + exprs.push_back(ir::Expr(-i)); + } + min_expr_value_ = -i; + ret.push_back(SearchState(ir::IRSchedule(ir::ModuleExpr(exprs)))); + } + return ret; + } + + private: + int module_expr_size_ = 10; + int min_expr_value_ = 0; +}; + +class MockCostModel : public ExprCostModel { + float Predict(const ir::ModuleExpr& sample, const common::Target& target) const override { + float cost = 0.0f; + std::vector exprs = sample.GetExprs(); + for (const ir::Expr& expr : exprs) { + if (expr.as_int32()) { + cost += static_cast((expr.as_int32())); + } + } + return cost; + } +}; + +TEST(EvolutionarySearch, GetOneBest) { + TuneTask mock_tune_task; + mock_tune_task.serialized_key = "mock_task"; + mock_tune_task.target = common::DefaultTarget(); + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); + MockCostModel cost_model; + TuningOptions options; + Database db(2); + EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db); + + MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); + // Ownership is transferred so don't delete mock_search_space + evolutionary_search.SetSearchSpace(mock_search_space); + SearchState best_state = evolutionary_search.SearchModuleExpr(options); + std::vector exprs = best_state->ir_schedule.GetModule().GetExprs(); + EXPECT_GE(exprs.size(), 1UL); + for (const ir::Expr& e : exprs) { + EXPECT_EQ(e.as_int32(), mock_search_space->GetMinExprValue()); + } +} + +TEST(EvolutionarySearch, GetEpsGreedy) { + TuneTask mock_tune_task; + mock_tune_task.serialized_key = "mock_task"; + mock_tune_task.target = common::DefaultTarget(); + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + task_registry->Regist(mock_tune_task.serialized_key, ir::ModuleExpr({ir::Expr(0)})); + ExprCostModel cost_model; + TuningOptions options; + Database db(2); + EvolutionarySearch evolutionary_search(mock_tune_task, cost_model, &db); + + MockSearchSpace* mock_search_space = new MockSearchSpace(mock_tune_task); + // Ownership is transferred so don't delete mock_search_space + evolutionary_search.SetSearchSpace(mock_search_space); + std::vector search_states = evolutionary_search.SearchModuleExprEpsGreedy(options); + + EXPECT_GE(search_states.size(), 1UL); + size_t expr_size = static_cast(mock_search_space->GetModuleExprSize()); + for (const SearchState& state : search_states) { + EXPECT_EQ(state->ir_schedule.GetModule().GetExprs().size(), expr_size); + } +} + +TEST(EvolutionarySearch, Evolve) { + auto target = common::DefaultNVGPUTarget(); + auto tasks = CreateTasks(tests::OpBuilder("matmul").Build({{"X", {32, 32}}, {"Y", {32, 32}}}), target); + CHECK_EQ(tasks.size(), 1); + ExprCostModel cost_model; + std::vector cost_model_samples(1); + std::vector cost_model_labels(1); + for (size_t i = 0; i < 2; ++i) { + ir::ModuleExpr 
me({ir::Expr(tasks[0].lowered_funcs[0])}); + cost_model_samples[0] = &me; + cost_model_labels[0] = i + 10; + cost_model.Update(cost_model_samples, cost_model_labels, target); + } + + Database db(2); + TuningOptions options; + options.evolution_pick_database_topk = 0; + + EvolutionarySearch evolutionary_search(tasks[0], cost_model, &db); + + int num_population = 10; + std::vector init_sketch = evolutionary_search.TestInitSketch(num_population, "rule_prune"); + for (int i = 0; i < num_population; ++i) { + ir::ModuleExpr me(init_sketch[i]->ir_schedule.GetModule()); + cost_model_samples[0] = &me; + cost_model_labels[0] = i; + cost_model.Update(cost_model_samples, cost_model_labels, target); + } + VLOG(6) << "init sketch costs:"; + for (auto s : init_sketch) { + VLOG(6) << "cost = " << s->predicted_cost; + } + std::vector*population_pre_ptr = &init_sketch, *population_next_ptr; + std::vector population; + for (int i = 0; i < 10; ++i) { + population = evolutionary_search.TestEvolve(*population_pre_ptr, /*cross_over_num*/ 0, /*ret_num*/ 10); + population_next_ptr = &population; + VLOG(6) << "population[" << i + 1 << "] costs:"; + double total_cost_pre = 0.0, total_cost_next = 0.0; + for (auto s : *population_pre_ptr) { + total_cost_pre += s->predicted_cost; + } + for (auto s : *population_next_ptr) { + total_cost_next += s->predicted_cost; + VLOG(6) << "cost = " << s->predicted_cost; + } + VLOG(6) << "total_cost_next = " << total_cost_next; + CHECK_LE(total_cost_next, total_cost_pre); + std::swap(population_pre_ptr, population_next_ptr); + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt new file mode 100644 index 0000000000000..308f9a91feea5 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/CMakeLists.txt @@ -0,0 +1,8 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + mutate_rule.cc + mutate_tile_size.cc + ) + +cc_test(test_mutate_tile_size SRCS mutate_tile_size_test.cc DEPS cinncore) diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc new file mode 100644 index 0000000000000..8e07e0d572788 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
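The factory defined below is how mutation rules are obtained by name. As a caller-side orientation, a hedged sketch (illustrative only; `trace` and `rand_seed` are assumed to come from the surrounding tuning context):

    std::unique_ptr<MutateRule> rule = MutateRule::Make("mutate_tile_size");
    ir::ScheduleDesc mutated = rule->Apply(trace, &rand_seed);  // returns the original trace when mutation fails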
+ +#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h" + +#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h" + +namespace cinn { +namespace auto_schedule { + +std::unique_ptr MutateRule::Make(const std::string& name) { + if (name == "mutate_tile_size") { + return std::make_unique(); + } else { + LOG(FATAL) << "MutateRule " << name << " is not supported."; + } + return nullptr; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h new file mode 100644 index 0000000000000..b650a9c746763 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h @@ -0,0 +1,48 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/ir/schedule_desc.h" +#include "cinn/utils/random_engine.h" + +namespace cinn { +namespace auto_schedule { + +/** + * Base class for rules of mutate, + * is used for mutating the trace(ScheduleDesc) to explore the search space. + */ +class MutateRule { + public: + MutateRule() = default; + + /** + * @brief Apply the mutate rule to the given trace. + * @param trace The given trace for mutation. + * @param rand_seed The random seed for mutation. + * @return The mutated trace. + */ + virtual ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) = 0; + + /** + * @brief Create a MutateRule with name. + * @param name The name of mutate rule, consisting of lowercase letters and underscores + * @return The created MutateRule. + */ + static std::unique_ptr Make(const std::string& name); +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc new file mode 100644 index 0000000000000..bc59bf668198d --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h" + +namespace cinn { +namespace auto_schedule { + +using ::cinn::ir::ScheduleDesc; +using ::cinn::utils::LinearRandomEngine; + +using SampledTile = std::tuple, int>; + +static std::vector Factorize(int n) { + std::vector res; + for (int i = 1; i * i <= n; ++i) { + if (n % i == 0) { + res.push_back(i); + if (i * i != n) { + res.push_back(n / i); + } + } + } + std::sort(res.begin(), res.end()); + return res; +} + +std::vector FindSampledTiles(const ScheduleDesc& trace) { + std::vector tiles; + int step_idx = 0; + for (auto&& step : trace.Steps()) { + if (step.type == "TagPostSchedule") { + break; + } + if (step.type == "SamplePerfectTile") { + std::vector tile_factors = absl::get>(step.attrs.at("decision")); + CHECK(tile_factors.size() >= 2) << "factors size must be greater equal than 2, which is " << tile_factors.size(); + tiles.push_back(std::make_tuple(step, tile_factors, step_idx)); + } + ++step_idx; + } + + return tiles; +} + +ScheduleDesc DoMutateTileSize(const ScheduleDesc& trace, + const SampledTile& tile, + LinearRandomEngine::StateType* rand_seed) { + ScheduleDesc::Step step = std::get<0>(tile); + std::vector tile_factors = std::get<1>(tile); + int split_size = tile_factors.size(); + // Step 1. Choose 2 loops with index: 'loop_x' and 'loop_y' + int loop_x, loop_y; + + bool all_one_factors = true; + for (int t : tile_factors) { + if (t != 1) { + all_one_factors = false; + break; + } + } + if (all_one_factors) { + VLOG(6) << "Factors are all 1, unable to mutate, return the original trace"; + return trace; + } + + while (true) { + VLOG(6) << "while (true) loop in DoMutateTileSize"; + loop_x = utils::SampleUniformInt(0, split_size, rand_seed); + if (tile_factors.at(loop_x) <= 1) { + continue; + } + loop_y = utils::SampleUniformInt(0, split_size - 1, rand_seed); + if (loop_y >= loop_x) { + ++loop_y; + } + std::vector optional_factors = Factorize(tile_factors.at(loop_x)); + // Step 2. Choose the divisor for mutate. + int divisor; + if (loop_y == split_size - 1) { + int max_innermost_factor = absl::get(step.attrs.at("max_innermost_factor")); + int max_optional_factor_idx = optional_factors.size() - 1; + for (; max_optional_factor_idx > 0; --max_optional_factor_idx) { + if (optional_factors.at(max_optional_factor_idx) * tile_factors.at(loop_y) <= max_innermost_factor) { + break; + } + } + if (max_optional_factor_idx == 0) { + if (split_size <= 2) { + VLOG(6) << "Unable to mutate, return the original trace"; + return trace; + } + continue; + } + divisor = optional_factors.at(utils::SampleUniformInt(1, max_optional_factor_idx + 1, rand_seed)); + } else { + divisor = optional_factors.at(utils::SampleUniformInt(1, optional_factors.size(), rand_seed)); + } + // Step 3. Determine the new tile value + VLOG(6) << "DoMutateTileSize: divisor = " << divisor << ", before mutate: \n" + << "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y + << "] = " << tile_factors[loop_y]; + tile_factors[loop_x] /= divisor; + tile_factors[loop_y] *= divisor; + VLOG(6) << "after mutate: \n" + << "factors[" << loop_x << "] = " << tile_factors[loop_x] << ", factors[" << loop_y + << "] = " << tile_factors[loop_y]; + // Step 4. 
Create a new step with new tile values and return the new trace + int step_idx = std::get<2>(tile); + return trace.ForkAndUpdate(step_idx, tile_factors, true); + } +} + +ScheduleDesc MutateTileSize::Apply(const ScheduleDesc& trace, LinearRandomEngine::StateType* rand_seed) { + VLOG(6) << "Start applying MutateTileSize, old trace: \n" << trace.DebugString(); + std::vector sample_tile_steps; + std::vector> sample_tile_data; + + auto sampled_tiles = FindSampledTiles(trace); + if (sampled_tiles.size() == 0) { + VLOG(6) << "MutateTileSize failed, try other mutate rules."; + return trace; + } + int sample_step_idx = utils::SampleUniformInt(0, sampled_tiles.size(), rand_seed); + auto new_trace = DoMutateTileSize(trace, sampled_tiles.at(sample_step_idx), rand_seed); + VLOG(6) << "End applying MutateTileSize, new trace: \n" << new_trace.DebugString(); + return new_trace; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h new file mode 100644 index 0000000000000..2313a38577c38 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.h" + +namespace cinn { +namespace auto_schedule { + +/** + * The rule to mutate tile size, witch will modify the factors of the Split primitive. + */ +class MutateTileSize : public MutateRule { + public: + MutateTileSize() = default; + + ir::ScheduleDesc Apply(const ir::ScheduleDesc& trace, utils::LinearRandomEngine::StateType* rand_seed) override; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc new file mode 100644 index 0000000000000..c8b4ce0a27ae6 --- /dev/null +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
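The test below exercises the rule end-to-end; the core pattern, mirroring the test body, is to mutate the recorded trace and replay it onto a fresh schedule (sketch; names follow the test):

    MutateTileSize mutator;
    ir::ScheduleDesc mutated = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
    mutated.Replay(&new_ir_schedule, true);  // re-applies the mutated steps to new_ir_schedule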
+ +#include "cinn/auto_schedule/search_strategy/mutate_rule/mutate_tile_size.h" + +#include +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir_schedule.h" + +namespace cinn { +namespace auto_schedule { + +TEST(MutateTileSize, Basic) { + srand(0); + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + + const int kSize = 32; + Expr M(kSize); + Expr N(kSize); + Expr K(kSize); + + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + Var k(K.as_int32(), "reduce_axis_k"); + ir::Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + poly::StageMap stages = CreateStages({A, B, C}); + std::vector funcs = + lang::LowerVec("TestMutateTileSize_Basic", stages, {A, B, C}, {}, {}, nullptr, target, true); + + ir::Expr ast_expr = funcs[0]->body; + VLOG(6) << "Original Expr: "; + VLOG(6) << ast_expr; + ir::ModuleExpr module_expr({ast_expr}); + // We need to fix the seed as a constant to ensure that the result can be repeated. + utils::LinearRandomEngine::StateType rand_seed = 123; + ir::IRSchedule ir_schedule(module_expr, rand_seed); + ir::IRSchedule new_ir_schedule(ir_schedule); + + // apply schedule + auto loops = ir_schedule.GetLoops("C"); + auto factors = ir_schedule.SamplePerfectTile(loops[0], 2, kSize); + auto splited = ir_schedule.Split(loops[0], factors); + + // apply mutate + MutateTileSize mutator; + ir::ScheduleDesc sch_desc = mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed); + sch_desc.Replay(&new_ir_schedule, true); + VLOG(6) << "Expr before mutate tile size: \n" << ir_schedule.GetModule().GetExprs()[0]; + VLOG(6) << "Expr after mutate tile size: \n" << new_ir_schedule.GetModule().GetExprs()[0]; + + std::string target_new_ir = R"ROC({ + ScheduleBlock(root) + { + serial for (i_1, 0, 2) + { + serial for (i_2, 0, 16) + { + serial for (j, 0, 32) + { + ScheduleBlock(C__reduce_init) + { + i0, i1 = axis.bind(((16 * i_1) + i_2), j) + C__reduce_init[i0, i1] = 0.00000000f + } + serial for (reduce_axis_k, 0, 32) + { + ScheduleBlock(C) + { + i0_0, i1_0, i2 = axis.bind(((16 * i_1) + i_2), j, reduce_axis_k) + C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0])) + } + } + } + } + } + } +})ROC"; + + auto get_ir_str = [](const ir::IRSchedule* ir_sch) -> std::string { + std::vector exprs = ir_sch->GetModule().GetExprs(); + EXPECT_EQ(exprs.size(), 1UL); + std::stringstream ss; + ss << exprs[0]; + return ss.str(); + }; + ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir); + + std::vector last_tile_factors = {2, 16}; + for (int i = 0; i < 10; ++i) { + sch_desc = mutator.Apply(sch_desc, &rand_seed); + for (auto&& step : sch_desc.Steps()) { + if (step.type == "SamplePerfectTile") { + std::vector tile_factors = absl::get>(step.attrs.at("decision")); + ASSERT_EQ(tile_factors.size(), last_tile_factors.size()); + ASSERT_NE(tile_factors[0], last_tile_factors[0]); + ASSERT_NE(tile_factors[1], last_tile_factors[1]); + ASSERT_EQ(tile_factors[0] * tile_factors[1], kSize); + last_tile_factors = tile_factors; + } + } + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/CMakeLists.txt b/paddle/cinn/auto_schedule/task/CMakeLists.txt new file mode 100644 index 0000000000000..f3dc34dad4c86 --- /dev/null +++ b/paddle/cinn/auto_schedule/task/CMakeLists.txt @@ -0,0 +1,12 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + task_creator.cc + task_optimizer.cc + tune_task.cc + ) 
+gather_srcs(cinnapi_src SRCS task_creator.cc task_optimizer.cc)
+
+cc_test(test_task_creator SRCS task_creator_test.cc DEPS cinncore)
+cc_test(test_tune_task SRCS tune_task_test.cc DEPS cinncore)
+cc_test(test_task_registry SRCS task_registry_test.cc DEPS cinncore)
diff --git a/paddle/cinn/auto_schedule/task/task_creator.cc b/paddle/cinn/auto_schedule/task/task_creator.cc
new file mode 100644
index 0000000000000..6d62ec2a7278d
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task/task_creator.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/task/task_creator.h"
+
+#include
+
+#include
+#include
+#include
+
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/pass.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+using ::cinn::common::GraphEdge;
+using ::cinn::common::GraphNode;
+using ::cinn::hlir::framework::Graph;
+using ::cinn::hlir::framework::Node;
+using ::cinn::hlir::framework::NodeData;
+
+std::vector<TuneTask> TaskCreator::CreateTuneTaskOpLevel(Graph* graph) {
+  std::vector<TuneTask> ret_tasks;
+
+  const std::vector<std::shared_ptr<Graph::Group>>* groups = &graph->fusion_groups;
+  std::vector<std::shared_ptr<Graph::Group>> non_fused_groups;
+  // The input graph hasn't gone through op fusion yet
+  if (graph->fusion_groups.empty()) {
+    hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"});
+    groups = &graph->fusion_groups;
+  }
+  VLOG(3) << "Graph groups size:" << groups->size();
+
+  for (const auto& sub_graph : *groups) {
+    ret_tasks.emplace_back(TuneTask());
+    ret_tasks.back().subgraph = sub_graph;
+    ret_tasks.back().target = graph->target_;
+  }
+  return ret_tasks;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task/task_creator.h b/paddle/cinn/auto_schedule/task/task_creator.h
new file mode 100644
index 0000000000000..6dd600f54e340
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task/task_creator.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/common/target.h"
+#include "cinn/hlir/framework/graph.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+/**
+ * Class to create auto-tuning tasks.
+ */ +class TaskCreator { + public: + std::vector CreateTuneTaskOpLevel(hlir::framework::Graph* graph); +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/task_creator_test.cc b/paddle/cinn/auto_schedule/task/task_creator_test.cc new file mode 100644 index 0000000000000..fe5638108e884 --- /dev/null +++ b/paddle/cinn/auto_schedule/task/task_creator_test.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/task/task_creator.h" + +#include + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/node.h" + +namespace cinn { +namespace auto_schedule { + +using ::cinn::frontend::NetBuilder; +using ::cinn::frontend::Program; +using ::cinn::hlir::framework::Graph; +using ::cinn::hlir::framework::Node; + +Program CreateAddProgram() { + constexpr int M = 32; + constexpr int N = 24; + + NetBuilder builder("net_builder"); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); + auto program = builder.Build(); + + return program; +} + +TEST(TaskCreator, Basic) { +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + Program prog = CreateAddProgram(); + auto graph = std::make_shared(prog, target); + + TaskCreator task_creator; + std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + + ASSERT_EQ(tasks.size(), 2UL); + for (TuneTask& task : tasks) { + std::shared_ptr subgraph = task.subgraph; + ASSERT_EQ(subgraph->CollectNodes().size(), 1UL); + ASSERT_EQ(subgraph->nodes[0]->op()->name, "elementwise_add"); + } +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/task_optimizer.cc b/paddle/cinn/auto_schedule/task/task_optimizer.cc new file mode 100644 index 0000000000000..b4afd2fa0bd4b --- /dev/null +++ b/paddle/cinn/auto_schedule/task/task_optimizer.cc @@ -0,0 +1,407 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
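One orientation note for Optimize() in this file: each strategy contributes a Result with a default cost, and the cheapest candidate wins, so the defaults encode a priority order until real measurements exist. A self-contained sketch of that ordering (editorial; uses plain pairs rather than the Result struct):

    #include <algorithm>
    #include <limits>
    #include <string>
    #include <utility>
    #include <vector>

    std::vector<std::pair<std::string, double>> candidates = {
        {"Evolution", std::numeric_limits<double>::max()},  // until predicted or measured
        {"Manual", 0.0},                                    // second best by default
        {"External", -1.0}};                                // best by default
    std::sort(candidates.begin(), candidates.end(),
              [](const auto& l, const auto& r) { return l.second < r.second; });
    // candidates.front().first == "External" when nothing has been measured yet.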
+ +#include "cinn/auto_schedule/task/task_optimizer.h" + +#include + +#include +#include + +#include "cinn/auto_schedule/analysis/analyze_ir.h" +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" +#include "cinn/auto_schedule/measure/measure.h" +#include "cinn/auto_schedule/search_strategy/evolutionary_search.h" +#include "cinn/common/target.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/hlir/op/external_api_registry.h" +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/transform_gpu_forloop.h" +#include "cinn/runtime/flags.h" +#include "cinn/utils/string.h" +#ifdef CINN_WITH_CUDA +#include + +#include "cinn/backends/cuda_util.h" +#endif + +DECLARE_bool(auto_schedule_use_cost_model); + +namespace cinn { +namespace auto_schedule { + +using cinn::hlir::op::ExternalApiRegistry; + +// *** forward declarations of auxiliary functions to be used in this file only *** +// update a scheduled function with several post-processors +ir::LoweredFunc FuncWithUpdatedBody(const common::Target& target, const ir::LoweredFunc& old_func, ir::Expr& body); +// check whether a scheduled lowered function is valid +bool PruneInvalid(const ir::LoweredFunc& lowered_func, const common::Target& target); +// exclude some special tasks +bool IsForbiddenToTune(const TuneTask* task); +// tell whether the task has been wrapped by custom_call in TransToCustomCallPass +bool IsWrappedByCustomCall(const TuneTask* task); +// tell whether the task has registered external api +bool HasExternalApi(const TuneTask* task); + +TaskOptimizer::TaskOptimizer(TuneTask* task, + ScheduleMeasurer* schedule_measurer, + Database* database, + utils::LinearRandomEngine::StateType rand_seed) + : task_(task), + schedule_measurer_(schedule_measurer), + database_(database), + cost_model_(), + rand_seed_(utils::LinearRandomEngine::NormalizeState(rand_seed)) {} + +FunctionGroup TaskOptimizer::Optimize(const TuningOptions& options) { + CHECK(task_->subgraph != nullptr) << "subgraph can't be empty"; + // task with forbidden or custom_call ops can't be tuned + if (IsForbiddenToTune(task_) || IsWrappedByCustomCall(task_)) { + return task_->op_lowerer->Lower(task_->subgraph); + } + // TODO(CtfGo): the input/output names of a Graph::Group will be changed in Lowering by OpLowerer currently, + // so we should revert them after following different lower methods, remove this hard code by fixing the + // decoupling between lowering and BuildInstructions + auto initial_input_names = task_->subgraph->input_names; + auto initial_output_names = task_->subgraph->output_names; + + std::vector candidates; + candidates.emplace_back(OptimizeByEvolution(options)); + candidates.emplace_back(OptimizeByManual(options.num_measure_trials > 0)); + if (HasExternalApi(task_)) { + candidates.emplace_back(OptimizeByExternal(options.num_measure_trials > 0)); + } + sort(candidates.begin(), candidates.end(), [](const auto& lhs, const auto& rhs) { return lhs.cost < rhs.cost; }); + auto&& best = candidates.front(); + VLOG(4) << "Total candidates=" << candidates.size() << ", the best from=" << best.from << ", cost=" << best.cost; + + // revert input/output names + task_->subgraph->input_names = initial_input_names; + task_->subgraph->output_names = initial_output_names; + return best.functions; +} + +TaskOptimizer::Result TaskOptimizer::OptimizeByManual(bool need_measured) { + static constexpr char* kManualMeasuredKeyPrefix = 
"@ManualMeasured:\n"; + TaskOptimizer::Result result("Manual"); + result.functions = task_->op_lowerer->Lower(task_->subgraph); + + // pack functions body + std::vector func_bodys; + for (const ir::LoweredFunc& func : result.functions) { + func_bodys.push_back(func->body); + } + + SearchState state(ir::IRSchedule(ir::ModuleExpr(std::move(func_bodys)))); + // the manual is regarded as the second best in default, so we set its cost 0.0 + result.cost = 0.0; + + // add the specific prefix in front of serialized_key to be store/load measured record for manual schedule + std::string measured_key = kManualMeasuredKeyPrefix + task_->serialized_key; + if (need_measured && database_->Count(measured_key) == 0) { + std::vector inputs(1); + inputs.back().task = task_; + inputs.back().lowered_funcs = result.functions; + VLOG(4) << "Measure manual schedule"; + std::vector measure_outputs = schedule_measurer_->Measure(inputs); + database_->AddRecord(TuningRecord(measured_key, state, measure_outputs[0].execution_cost)); + } + + auto measured_records = database_->LookUp(measured_key); + if (!measured_records.empty()) { // update result.cost by measured if exists + result.cost = measured_records[0].execution_cost; + } + return result; +} + +TaskOptimizer::Result TaskOptimizer::OptimizeByExternal(bool need_measured) { + static constexpr char* kExternalMeasuredKeyPrefix = "@ExternalMeasured:\n"; + TaskOptimizer::Result result("External"); + auto nodes = task_->subgraph->CollectNodes(); + auto* first_node = nodes.front(); + + // set the necessary field for lowering with external api + std::string original_op = first_node->op()->name; + first_node->attrs.attr_store["original_op"] = original_op; + first_node->attrs.op = hlir::framework::Operator::Get("custom_call"); + result.functions = task_->op_lowerer->Lower(task_->subgraph); + + // add the specific prefix in front of serialized_key to be store/load measured record for external api + result.cost = -1.0; // the external is regarded as the best in default, so we set its cost -1.0 + std::string measured_key = kExternalMeasuredKeyPrefix + task_->serialized_key; + if (need_measured && database_->Count(measured_key) == 0) { + std::vector inputs(1); + inputs.back().task = task_; + inputs.back().lowered_funcs = result.functions; + VLOG(4) << "Measure external api"; + std::vector measure_outputs = schedule_measurer_->Measure(inputs); + // the SearchState of external is invalid and will not be used, so we just put a temporary one + database_->AddRecord(TuningRecord(measured_key, SearchState(ir::IRSchedule()), measure_outputs[0].execution_cost)); + } + + auto measured_records = database_->LookUp(measured_key); + if (!measured_records.empty()) { // update result.cost by measured if exists + result.cost = measured_records[0].execution_cost; + } + return result; +} + +bool IsForbiddenToTune(const TuneTask* task) { + // TODO(CtfGo): some operators may change its linked edges in + // TransToCustomCallPass, like conv2d, we will skip these ops in auto-schedule + // because they can't revert original links for no schedule and manual schedule lowering. 
+ static std::unordered_set links_changed_ops = {"conv2d"}; + auto nodes = task->subgraph->CollectNodes(); + auto&& op_name = nodes.front()->op()->name; + if (nodes.size() == 1 && links_changed_ops.count(op_name)) { + VLOG(5) << "Op:" << op_name << " is forbidden to call external_api"; + return true; + } + + return false; +} + +bool HasExternalApi(const TuneTask* task) { + auto nodes = task->subgraph->CollectNodes(); + auto* first_node = nodes.front(); + if (nodes.size() == 1 && ExternalApiRegistry::Global()->Has(first_node->op()->name, task->target)) { + return true; + } + return false; +} + +bool IsWrappedByCustomCall(const TuneTask* task) { + auto nodes = task->subgraph->CollectNodes(); + auto* first_node = nodes.front(); + if (nodes.size() == 1 && first_node->op()->name == "custom_call") { + CHECK(first_node->attrs.attr_store.count("original_op")) << "a custom_call op must store its original op name"; + std::string op_name = absl::get(first_node->attrs.attr_store.at("original_op")); + VLOG(5) << "Op:" << op_name << " was wrapped as custom_call"; + return true; + } + + return false; +} + +TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(const TuningOptions& options) { + CHECK_EQ(options.num_measure_trials % options.num_samples_per_iteration, 0) + << "TuningOptions.num_measure_trials % TuningOptions.num_samples_per_iteration must be 0."; + + VLOG(4) << "Optimizing TuneTask with num_measure_trials:" << options.num_measure_trials + << ", LoweredFunc before optimization is:"; + VLOG(4) << "lowered function size = " << task_->lowered_funcs.size(); + for (size_t i = 0; i < task_->lowered_funcs.size(); ++i) { + VLOG(4) << "lowered_funcs[" << i << "] detail:\n" << task_->lowered_funcs[i]; + } + + if (evolutionary_search_ == nullptr) { + // TODO(zhhsplendid): check whether the options is same as previous, + // if not, we should create new EvolutionarySearch + evolutionary_search_ = + std::make_unique(*task_, cost_model_, database_, utils::ForkRandomState(&rand_seed_)); + } + + TaskOptimizer::Result result("Evolution"); + auto& optimized_funcs = result.functions; + auto& best_cost = result.cost; + // use initial lowered function as default result + optimized_funcs = optim::IRCopy(task_->lowered_funcs); + if (options.num_measure_trials == 0) { // no need to measure and simply return the best searched + std::vector measure_candidates; + std::vector states = SearchOneRound(options, &measure_candidates); + if (!states.empty()) { + if (FLAGS_auto_schedule_use_cost_model) { + best_cost = cost_model_.Predict(states.front()->ir_schedule.GetModule(), task_->target); + } + optimized_funcs = measure_candidates[0].lowered_funcs; + } else { + LOG(WARNING) << "No valid candidate searched, will return initial state"; + } + return result; + } + + int measured_count = 0; + uint32_t continuous_empty_cnt = 0; + while (measured_count < options.num_measure_trials) { + VLOG(4) << "Launch a new search, current measured_count:" << measured_count; + std::vector measure_inputs; + std::vector states = SearchOneRound(options, &measure_inputs); + if (states.empty()) { // no new valid candidate achieved + ++continuous_empty_cnt; + if (continuous_empty_cnt <= kMaxRetryContinuousEmpty_) { + VLOG(4) << "No valid state searched, continuous_empty_cnt=" << continuous_empty_cnt; + continue; + } else { + LOG(WARNING) + << "OptimizeByEvolution will be exited in advance due to continuous invalid search, final measured_count=" + << measured_count; + break; + } + } + continuous_empty_cnt = 0; // reset if get valid candidates + + 
VLOG(4) << "ScheduleMeasurer start with input size=" << measure_inputs.size(); + std::vector measure_outputs = schedule_measurer_->Measure(measure_inputs); + CHECK_EQ(measure_outputs.size(), states.size()) + << "ScheduleMeasurer didn't output same number of MeasureOutput of states in TaskOptimizer"; + // record to database + for (size_t i = 0; i < states.size(); ++i) { + database_->AddRecord( + TuningRecord(measure_inputs[i].task->serialized_key, states[i], measure_outputs[i].execution_cost)); + } + + // update cost model + if (FLAGS_auto_schedule_use_cost_model) { + std::vector cost_model_samples(states.size()); + std::vector cost_model_labels(states.size()); + for (size_t i = 0; i < states.size(); ++i) { + cost_model_samples[i] = &(states[i]->ir_schedule.GetModule()); + cost_model_labels[i] = measure_outputs[i].execution_cost; + } + VLOG(4) << utils::StringFormat("Update CostModel with samples size=%lu,labels size=%lu", + cost_model_samples.size(), + cost_model_labels.size()); + cost_model_.Update(cost_model_samples, cost_model_labels, task_->target); + } + + // update the best + for (size_t i = 0; i < measure_outputs.size(); ++i) { + if (measure_outputs[i].execution_cost < best_cost) { + VLOG(4) << "Update best candidate with execution_cost:" << measure_outputs[i].execution_cost << "us"; + best_cost = measure_outputs[i].execution_cost; + optimized_funcs = measure_inputs[i].lowered_funcs; + } + } + + // count result size + measured_count += states.size(); + } + return result; +} + +std::vector TaskOptimizer::SearchOneRound(const TuningOptions& options, + std::vector* measure_candidates) { + std::vector states = evolutionary_search_->SearchModuleExprEpsGreedy(options); + VLOG(4) << JoinStatesDebugString("TaskOptimizer::EvolutionarySearch-Result", states, /*verbose=*/VLOG_IS_ON(5)); + + size_t valid_cnt = 0; + for (size_t i = 0; i < states.size(); ++i) { + std::vector best_exprs = states[i]->ir_schedule.GetModule().GetExprs(); + CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size()) + << "RuntimeError: Expr size is not equal to LoweredFunc size in TaskOptimizer"; + auto init_funcs = optim::IRCopy(task_->lowered_funcs); + std::vector valid_funcs; + for (size_t j = 0; j < best_exprs.size(); ++j) { + auto updated_f = UpdateFuncWithNewBody(task_->target, init_funcs[j], best_exprs[j]); + if (PruneInvalid(updated_f, task_->target)) { + VLOG(4) << "PruneInvalid states-" << i; + break; + } + valid_funcs.emplace_back(updated_f); + } + + // all functions are validated, collect this state to be measured + if (valid_funcs.size() == init_funcs.size()) { + states[valid_cnt++] = states[i]; + measure_candidates->emplace_back(MeasureInput()); + measure_candidates->back().task = task_; + measure_candidates->back().lowered_funcs = std::move(valid_funcs); + } + } + + states.erase(states.begin() + valid_cnt, states.end()); + CHECK_EQ(states.size(), measure_candidates->size()) << "result size of states not equal to measure_candidates"; + VLOG(4) << "EvolutionarySearch return size=" << states.size() << ", valid count=" << valid_cnt; + VLOG(4) << JoinStatesDebugString("TaskOptimizer::SearchOneRound-Result", states, /*verbose=*/VLOG_IS_ON(5)); + return states; +} + +// detect the limit of available shared memory on the current NVGPU with CUDA runtime +size_t GetGPUSharedMemoryLimit() { +#ifdef CINN_WITH_CUDA + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaDeviceProp prop; + CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); + VLOG(4) << utils::StringFormat("GPU-%d GPUSharedMemoryLimit=%d", 
device_id, prop.sharedMemPerBlock); + return prop.sharedMemPerBlock; +#else + return 0; +#endif +} + +// detect the limit of available local/stack memory on the current NVGPU with CUDA runtime +size_t GetGPULocalStackLimit() { +#ifdef CINN_WITH_CUDA + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + cudaDeviceProp prop; + CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); + size_t limit = prop.totalGlobalMem / prop.multiProcessorCount / prop.maxThreadsPerMultiProcessor; + VLOG(4) << utils::StringFormat( + "GPU-%d totalGlobalMem=%lu,maxThreadsPerMultiProcessor=%d,multiProcessorCount=%d, calculated " + "GPULocalStackLimit=%lu", + device_id, + prop.totalGlobalMem, + prop.multiProcessorCount, + prop.maxThreadsPerMultiProcessor, + limit); + return limit; +#else + return 0; +#endif +} + +// check whether usage of the specific memory type in the lowered_func exceeds hardware limit +bool IsGPUMemoryUsageExceedLimit(const ir::LoweredFunc& lowered_func, + const ir::MemoryType& used_memory_type, + const size_t limit_bytes) { + std::unordered_set visited; + size_t used_bytes_cnt = 0; + for (auto&& buf : lowered_func->temp_bufs) { + VLOG(5) << "temp buf name=" << buf->name << ", numel=" << buf->numel() << ",dtype=" << buf->dtype; + if (buf->memory_type == used_memory_type && !visited.count(buf->name)) { + used_bytes_cnt += buf->numel() * buf->dtype.bytes(); + visited.insert(buf->name); + } + } + VLOG(5) << "total used_bytes_cnt=" << used_bytes_cnt; + return used_bytes_cnt >= limit_bytes; +} + +bool PruneInvalid(const ir::LoweredFunc& lowered_func, const common::Target& target) { + static const size_t kGPUSharedMemoryLimitBytes = GetGPUSharedMemoryLimit(); + static const size_t kGPULocalStackLimitBytes = GetGPULocalStackLimit(); + + if (target == common::DefaultNVGPUTarget()) { + if (IsGPUMemoryUsageExceedLimit(lowered_func, ir::MemoryType::GPUShared, kGPUSharedMemoryLimitBytes)) { + VLOG(5) << ir::MemoryType::GPUShared << " memory usage exceeds limit, func:\n" << lowered_func; + return true; + } + + if (IsGPUMemoryUsageExceedLimit(lowered_func, ir::MemoryType::GPULocal, kGPULocalStackLimitBytes)) { + VLOG(5) << ir::MemoryType::GPULocal << " memory usage exceeds limit, func:\n" << lowered_func; + return true; + } + } + return false; +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/task_optimizer.h b/paddle/cinn/auto_schedule/task/task_optimizer.h new file mode 100644 index 0000000000000..68fb9f8457324 --- /dev/null +++ b/paddle/cinn/auto_schedule/task/task_optimizer.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
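As a usage sketch for the class declared below (illustrative; `task`, `measurer`, `database` and `options` are assumed to be set up by the surrounding auto-tuner):

    TaskOptimizer optimizer(&task, &measurer, &database, /*rand_seed=*/1);
    FunctionGroup tuned = optimizer.Optimize(options);  // best of evolution, manual and external candidates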
+ +#pragma once + +#include + +#include "cinn/auto_schedule/cost_model/expr_cost_model.h" +#include "cinn/auto_schedule/database/database.h" +#include "cinn/auto_schedule/measure/schedule_measurer.h" +#include "cinn/auto_schedule/search_strategy/evolutionary_search.h" +#include "cinn/auto_schedule/task/tune_task.h" +#include "cinn/auto_schedule/tuning.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/utils/random_engine.h" + +namespace cinn { +namespace auto_schedule { + +// This class is responsible for tuning a specific task, +// it will integrate necessary components to search the +// optimal schedule for the task. +class TaskOptimizer { + public: + TaskOptimizer(TuneTask* task, + ScheduleMeasurer* schedule_measurer, + Database* database, + utils::LinearRandomEngine::StateType rand_seed = -1); + + FunctionGroup Optimize(const TuningOptions& options); + + private: + struct Result { + std::string from; + double cost; + FunctionGroup functions; + Result(const std::string& from_type) : from(from_type), cost(std::numeric_limits::max()) {} + }; + + Result OptimizeByManual(bool need_measure); + Result OptimizeByExternal(bool need_measure); + Result OptimizeByEvolution(const TuningOptions& options); + + // call search candidates once by EvolutionarySearch and prune invalid ones + std::vector SearchOneRound(const TuningOptions& options, std::vector* measure_candidates); + + private: + // the max retry times if continuously get empty result + static constexpr uint32_t kMaxRetryContinuousEmpty_ = 3; + TuneTask* task_; + ScheduleMeasurer* schedule_measurer_; + std::unique_ptr evolutionary_search_ = nullptr; + ExprCostModel cost_model_; + Database* database_; + utils::LinearRandomEngine::StateType rand_seed_; +}; + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/task_registry.h b/paddle/cinn/auto_schedule/task/task_registry.h new file mode 100644 index 0000000000000..ad069ecac8343 --- /dev/null +++ b/paddle/cinn/auto_schedule/task/task_registry.h @@ -0,0 +1,79 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "cinn/ir/ir_schedule.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/utils/registry.h" + +namespace cinn { + +namespace auto_schedule { + +struct InitialTaskInfo { + std::string task_key; + ir::ModuleExpr module_expr; + + InitialTaskInfo(const std::string& task_key, const ir::ModuleExpr& module_expr) + : task_key(task_key), module_expr(module_expr) {} +}; + +// Global task registry, used to save the initial ModuleExpr of each task. +class InitialTaskRegistry : public Registry { + public: + static InitialTaskRegistry* Global() { + static InitialTaskRegistry x; + return &x; + } + + // Get the initial ModuleExpr of a task. 
+  inline const InitialTaskInfo* Get(const std::string& task_key) {
+    const InitialTaskInfo* task_info = Registry<InitialTaskInfo>::Find(task_key);
+    CHECK(task_info) << "InitialTaskInfo [" << task_key << "] is not registered";
+    return task_info;
+  }
+
+  // Check whether the task info with task_key exists.
+  inline bool Has(const std::string& task_key) { return nullptr != Registry<InitialTaskInfo>::Find(task_key); }
+
+  // Regist the initial ModuleExpr of a task into the map.
+  inline void Regist(const std::string& task_key, const ir::ModuleExpr& module_expr) {
+    std::lock_guard<std::mutex> guard(registering_mutex);
+    if (fmap_.count(task_key) == 0) {
+      InitialTaskInfo* task_info = new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
+      __REGISTER__(task_key, task_info);
+    }
+  }
+
+ private:
+  InitialTaskRegistry() = default;
+  CINN_DISALLOW_COPY_AND_ASSIGN(InitialTaskRegistry);
+
+  // Regist the initial ModuleExpr of a task.
+  inline InitialTaskInfo* __REGISTER__(const std::string& task_key, InitialTaskInfo* task_info) {
+    fmap_[task_key] = task_info;
+    const_list_.push_back(task_info);
+    entry_list_.push_back(task_info);
+    return task_info;
+  }
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task/task_registry_test.cc b/paddle/cinn/auto_schedule/task/task_registry_test.cc
new file mode 100644
index 0000000000000..c94f0df743e9b
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task/task_registry_test.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
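The registry contract exercised by this test reduces to a few calls (sketch; `key` and `expr` are assumed):

    InitialTaskRegistry* registry = InitialTaskRegistry::Global();
    registry->Regist(key, expr);   // stores an IRCopy; re-registering the same key is a no-op
    if (registry->Has(key)) {
      const ir::ModuleExpr& stored = registry->Get(key)->module_expr;  // Get() CHECK-fails on missing keys
    }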
+ +#include "cinn/auto_schedule/task/task_registry.h" + +#include +#include + +#include + +#include "cinn/auto_schedule/task/task_creator.h" +#include "cinn/auto_schedule/task/tune_task.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/utils/string.h" +#include "cinn/utils/type_defs.h" + +DECLARE_bool(auto_schedule_use_cost_model); +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace auto_schedule { + +std::vector CreateTasks(hlir::framework::Graph* graph, const common::Target& target) { + // create tasks + TaskCreator task_creator; + std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph); + + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs>("infershape"); + + std::unique_ptr op_lowerer = + std::make_unique(dtype_dict, shape_dict, target); + for (TuneTask& task : tasks) { + task.Initialize(shape_dict, dtype_dict, op_lowerer.get()); + VLOG(3) << "Add a task with serialized_key:\n" << task.serialized_key; + } + + return tasks; +} + +std::shared_ptr CreateAddProgram(const common::Target& target) { + frontend::NetBuilder builder("test"); + + auto a = builder.CreateInput(Float(32), {1, 64, 112, 112}, "A"); + auto b = builder.CreateInput(Float(32), {64}, "B"); + auto c = builder.Add(a, b, 1); + + return std::make_shared(builder.Build(), target); +} + +TEST(TestTaskRegistry, basic) { + FLAGS_auto_schedule_use_cost_model = true; + FLAGS_cinn_ir_schedule = true; + +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + std::shared_ptr graph = CreateAddProgram(target); + std::vector tasks = CreateTasks(graph.get(), target); + + InitialTaskRegistry* task_registry = InitialTaskRegistry::Global(); + + std::vector module_exprs; + for (const TuneTask& task : tasks) { + module_exprs.emplace_back(task.GetLoweredFuncBodyExprs()); + task_registry->Regist(task.serialized_key, module_exprs.back()); + } + + for (int i = 0; i < tasks.size(); ++i) { + std::string key = tasks[i].serialized_key; + VLOG(3) << "serialized_key = " << key; + ir::ModuleExpr new_expr = task_registry->Get(key)->module_expr; + + ASSERT_EQ(new_expr.GetExprs().size(), module_exprs[i].GetExprs().size()); + for (int j = 0; j < new_expr.GetExprs().size(); ++j) { + VLOG(3) << "expr " << j << " of task " << key << " : " << new_expr.GetExprs().at(j); + ASSERT_EQ(utils::GetStreamCnt(new_expr.GetExprs().at(j)), utils::GetStreamCnt(module_exprs[i].GetExprs().at(j))); + } + } + + bool flag = task_registry->Has(tasks[0].serialized_key); + ASSERT_EQ(flag, true); + + flag = task_registry->Has("not_exist"); + ASSERT_EQ(flag, false); +} + +} // namespace auto_schedule +} // namespace cinn diff --git a/paddle/cinn/auto_schedule/task/tune_task.cc b/paddle/cinn/auto_schedule/task/tune_task.cc new file mode 100644 index 0000000000000..80998c3825a47 --- /dev/null +++ b/paddle/cinn/auto_schedule/task/tune_task.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/task/tune_task.h" + +#include + +#include +#include + +#include "cinn/auto_schedule/analysis/analyze_ir.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace auto_schedule { + +void TuneTask::Initialize(const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict, + hlir::framework::OpLowerer* lower_handler) { + CHECK(lower_handler != nullptr) << "op_lowerer can't be nullptr"; + op_lowerer = lower_handler; + + // Set lowered_funcs and analyze output names. + this->lowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph); + this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs); + this->serialized_key = SerializeToString(shape_dict, dtype_dict); +} + +std::vector TuneTask::GetLoweredFuncBodyExprs() const { + std::vector result; + for (const ir::LoweredFunc& func : lowered_funcs) { + result.push_back(func->body); + } + return result; +} + +std::string TuneTask::SerializeToString(const absl::flat_hash_map& shape_dict, + const absl::flat_hash_map& dtype_dict) { + std::stringstream ss; + ss << target << "\n\n"; // print target + + // local function to print dtype,shape of out/in variables of the specified node + auto print_node_links_fn = [&](const std::vector>& links, bool is_input) { + int printed_num = 0; + for (auto&& edge : links) { + const auto* var_node = is_input ? edge->source()->safe_as() + : edge->sink()->safe_as(); + CHECK(var_node) << "var node invalid"; + auto sit = shape_dict.find(var_node->id()); + CHECK(sit != shape_dict.end()) << "can't find shape of variable:" << var_node->id(); + auto dit = dtype_dict.find(var_node->id()); + CHECK(dit != dtype_dict.end()) << "can't find dtype of variable:" << var_node->id(); + if (printed_num > 0) { + ss << ", "; + } + ++printed_num; + // TODO(CtfGo): CINN uses the names of input/output NodeData ids as arguments of the LoweredFunc in the Lower + // process, so it will result in different LoweredFuncs for two Nodes even though they represents the same + // operator. Here we add `var_node->id()` into the serialized_key to distinguish them, otherwise AutoTuner will + // get wrong TuningRecords when querying cached results from database. In the future, we should remove + // name-related limit in Lower process, to avoid duplicate tuning tasks with same operators. 
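+      // (Editorial example) For a float32 input of shape [32, 24] named "A",
+      // the fragment emitted below would look roughly like: A->float32[32,24]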
+      ss << var_node->id() << "->" << cinn::common::Type2Str(dit->second) << "[" << utils::Join(sit->second, ",")
+         << "]";
+    }
+  };
+
+  // print each node of the subgraph
+  ss << "Group {\n";
+  for (auto&& node : subgraph->CollectNodes()) {
+    ss << "  (";
+    print_node_links_fn(node->outlinks_in_order(), false);
+    ss << ") = " << node->op()->name << "(";
+    print_node_links_fn(node->inlinks_in_order(), true);
+    ss << ")\n";
+  }
+  ss << "}\n";
+
+  return ss.str();
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task/tune_task.h b/paddle/cinn/auto_schedule/task/tune_task.h
new file mode 100644
index 0000000000000..4963a36fc4133
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task/tune_task.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <absl/container/flat_hash_map.h>
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "cinn/common/target.h"
+#include "cinn/common/type.h"
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op_lowering.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/lowered_func.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+class TuneTask {
+ public:
+  TuneTask() = default;
+  explicit TuneTask(std::shared_ptr<hlir::framework::Graph::Group> group) : subgraph(group) {}
+  // Initialize a task
+  void Initialize(const absl::flat_hash_map<std::string, hlir::framework::shape_t>& shape_dict,
+                  const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict,
+                  hlir::framework::OpLowerer* lower_handler);
+  // Extract the bodies of lowered_funcs and return them
+  std::vector<ir::Expr> GetLoweredFuncBodyExprs() const;
+
+  // In CINN, we use hlir::framework::Graph::Group to represent a fused
+  // sub-graph (an op that won't be fused becomes a Group with size 1).
+  std::shared_ptr<hlir::framework::Graph::Group> subgraph;
+  // Lower handler, not owned
+  hlir::framework::OpLowerer* op_lowerer{nullptr};
+  // target of this task
+  common::Target target;
+  // stores the initial (un-optimized) LoweredFuncs
+  std::vector<ir::LoweredFunc> lowered_funcs;
+  // names of the output arguments of lowered_funcs
+  std::unordered_set<std::string> output_names;
+  // serialized string of this task; it records the structure, shapes, dtypes and
+  // input/output variable names of the subgraph, and is used as a cache key
+  std::string serialized_key;
+
+ private:
+  // Serialize this task as a string containing its distinguishing fields
+  std::string SerializeToString(const absl::flat_hash_map<std::string, hlir::framework::shape_t>& shape_dict,
+                                const absl::flat_hash_map<std::string, cinn::common::Type>& dtype_dict);
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task/tune_task_test.cc b/paddle/cinn/auto_schedule/task/tune_task_test.cc
new file mode 100755
index 0000000000000..9ff7ea26392cd
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task/tune_task_test.cc
@@ -0,0 +1,339 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/auto_schedule/task/tune_task.h" + +#include + +#include +#include +#include + +#include "cinn/auto_schedule/task/task_creator.h" +#include "cinn/common/context.h" +#include "cinn/common/target.h" +#include "cinn/frontend/net_builder.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op_lowering.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/framework/scope.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/utils/string.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace auto_schedule { + +using ::cinn::frontend::NetBuilder; +using ::cinn::frontend::Program; +using ::cinn::hlir::framework::OpLowerer; + +Program CreateAddProgram() { + constexpr int M = 32; + constexpr int N = 24; + + NetBuilder builder("net_builder"); + auto a = builder.CreateInput(Float(32), {M, N}, "A"); + auto b = builder.CreateInput(Float(32), {M, N}, "B"); + auto c = builder.Add(a, b); + auto d = builder.Add(a, c); + auto program = builder.Build(); + + return program; +} + +TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) { + // Auto tuner is combined with IR schedule + FLAGS_cinn_ir_schedule = true; + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + Program prog = CreateAddProgram(); + auto graph = std::make_shared(prog, target); + + TaskCreator task_creator; + std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + ASSERT_EQ(tasks.size(), 2UL); + + const auto& shape_dict = graph->GetAttrs>("infershape"); + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + OpLowerer op_lowerer(dtype_dict, shape_dict, target); + + std::stringstream ss; + for (TuneTask& task : tasks) { + task.Initialize(shape_dict, dtype_dict, &op_lowerer); + + std::vector exprs = task.GetLoweredFuncBodyExprs(); + VLOG(6) << "ir:Expr is: "; + for (const ir::Expr& e : exprs) { + VLOG(6) << e; + ss << e << std::endl; + } + } + + std::string expr_str = ss.str(); +#ifdef CINN_WITH_CUDA + std::string target_str = R"ROC( +{ + ScheduleBlock(root) + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_1) + { + i0, i1 = axis.bind(i, j) + var_1[i, j] = (A[i, j] + B[i, j]) + } + } + } + } +} +{ + ScheduleBlock(root_0) + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_2) + { + i0_0, i1_0 = axis.bind(i, j) + var_2[i, j] = (A[i, j] + var_1[i, j]) + } + } + } + } +} +)ROC"; +#else + std::string target_str = R"ROC( +{ + ScheduleBlock(root) + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_1) + { + i0, i1 = axis.bind(i, j) + var_1[i0, i1] = (A[i0, i1] + B[i0, i1]) + } + } + } + } +} +{ + ScheduleBlock(root_0) + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_2) + { + i0_0, i1_0 = axis.bind(i, j) + var_2[i0_0, i1_0] = (A[i0_0, i1_0] + var_1[i0_0, i1_0]) + } + } + } + } +} +)ROC"; +#endif + + 
EXPECT_EQ(utils::Trim(target_str), utils::Trim(expr_str)); +} + +TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) { + // Auto tuner is combined with IR schedule + FLAGS_cinn_ir_schedule = true; + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + Program prog = CreateAddProgram(); + auto graph = std::make_shared(prog, target); + ApplyPass(graph.get(), "OpFusionPass"); + + TaskCreator task_creator; + std::vector tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + + ASSERT_EQ(tasks.size(), 1UL); + + const auto& shape_dict = graph->GetAttrs>("infershape"); + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + + OpLowerer op_lowerer(dtype_dict, shape_dict, target); + + std::stringstream ss; + for (TuneTask& task : tasks) { + task.Initialize(shape_dict, dtype_dict, &op_lowerer); + + std::vector exprs = task.GetLoweredFuncBodyExprs(); + VLOG(6) << "ir:Expr is: "; + for (const ir::Expr& e : exprs) { + VLOG(6) << e; + ss << e << std::endl; + } + } + + std::string expr_str = ss.str(); +#ifdef CINN_WITH_CUDA + std::string target_str = R"ROC( +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_1) + { + i0, i1 = axis.bind(i, j) + var_1[i, j] = (A[i, j] + B[i, j]) + } + } + } + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_2) + { + i0_0, i1_0 = axis.bind(i, j) + var_2[i, j] = (A[i, j] + var_1[i, j]) + } + } + } + } + } +} +)ROC"; + +#else + std::string target_str = R"ROC( +{ + ScheduleBlock(root) + { + { + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_1) + { + i0, i1 = axis.bind(i, j) + var_1[i0, i1] = (A[i0, i1] + B[i0, i1]) + } + } + } + serial for (i, 0, 32) + { + serial for (j, 0, 24) + { + ScheduleBlock(var_2) + { + i0_0, i1_0 = axis.bind(i, j) + var_2[i0_0, i1_0] = (A[i0_0, i1_0] + var_1[i0_0, i1_0]) + } + } + } + } + } +} +)ROC"; +#endif + + EXPECT_EQ(utils::Trim(target_str), utils::Trim(expr_str)); +} + +TEST(TuneTask, SerializeToString) { + Context::Global().ResetNameId(); +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + Program prog = CreateAddProgram(); + auto graph = std::make_shared(prog, target); + + TaskCreator task_creator; + std::vector single_tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + + const auto& shape_dict = graph->GetAttrs>("infershape"); + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + OpLowerer op_lowerer(dtype_dict, shape_dict, target); + ASSERT_EQ(single_tasks.size(), 2UL); + for (auto&& task : single_tasks) { + task.Initialize(shape_dict, dtype_dict, &op_lowerer); + } + +#ifdef CINN_WITH_CUDA + std::string single_add_str = R"ROC(Target + +Group { + (var_1->float32[32,24]) = elementwise_add(A->float32[32,24], B->float32[32,24]) +} +)ROC"; +#else + std::string single_add_str = R"ROC(Target + +Group { + (var_1->float32[32,24]) = elementwise_add(A->float32[32,24], B->float32[32,24]) +} +)ROC"; +#endif + EXPECT_EQ(single_tasks[0].serialized_key, single_add_str); + + ApplyPass(graph.get(), "OpFusionPass"); + std::vector fused_tasks = task_creator.CreateTuneTaskOpLevel(graph.get()); + ASSERT_EQ(fused_tasks.size(), 1UL); + fused_tasks[0].Initialize(shape_dict, dtype_dict, &op_lowerer); + +#ifdef CINN_WITH_CUDA + std::string fused_expected_str = R"ROC(Target + +Group { + (var_1->float32[32,24]) = elementwise_add(A->float32[32,24], 
B->float32[32,24])
+  (var_2->float32[32,24]) = elementwise_add(A->float32[32,24], var_1->float32[32,24])
+}
+)ROC";
+#else
+  std::string fused_expected_str = R"ROC(Target

+Group {
+  (var_1->float32[32,24]) = elementwise_add(A->float32[32,24], B->float32[32,24])
+  (var_2->float32[32,24]) = elementwise_add(A->float32[32,24], var_1->float32[32,24])
+}
+)ROC";
+#endif
+  EXPECT_EQ(fused_tasks[0].serialized_key, fused_expected_str);
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/CMakeLists.txt b/paddle/cinn/auto_schedule/task_scheduler/CMakeLists.txt
new file mode 100644
index 0000000000000..d938b027a7c5f
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/CMakeLists.txt
@@ -0,0 +1,5 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS task_scheduler.cc round_robin.cc efficiency_priority.cc)
+
+cc_test(test_task_scheduler SRCS task_scheduler_test.cc DEPS cinncore)
diff --git a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc
new file mode 100644
index 0000000000000..a83f8004965c2
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/task_scheduler/efficiency_priority.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+int EfficiencyPriority::NextTaskId() {
+  while (cur_task_id_ < tasks_->size()) {
+    if (IsTaskToTune(&tasks_->at(cur_task_id_))) {
+      return cur_task_id_++;
+    }
+    ++cur_task_id_;
+  }
+  return -1;
+}
+
+// placeholder implementation: ignores `task` and only checks the configured threshold
+bool EfficiencyPriority::IsTaskToTune(const TuneTask* task) { return config_.minimum_gain_threshold > 0.0; }
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h
new file mode 100644
index 0000000000000..af6e5272b09fe
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/efficiency_priority.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// Schedules tasks with the efficiency_priority strategy, that is,
+// picking the task with the maximum expected gain ratio.
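+//
+// Note: the current implementation is only a stub. IsTaskToTune() checks the
+// configured Config::minimum_gain_threshold rather than any per-task estimate,
+// so a non-positive threshold makes NextTaskId() return -1 for every task
+// (this is what task_scheduler_test.cc relies on).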
+class EfficiencyPriority : public TaskScheduler {
+ public:
+  EfficiencyPriority(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {}
+
+  const char* Name() const override { return "efficiency_priority"; }
+
+  int NextTaskId() override;
+
+ private:
+  bool IsTaskToTune(const TuneTask* task);
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/round_robin.cc b/paddle/cinn/auto_schedule/task_scheduler/round_robin.cc
new file mode 100644
index 0000000000000..37af0cee556c0
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/round_robin.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/task_scheduler/round_robin.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+int RoundRobin::NextTaskId() {
+  if (cur_task_id_ < tasks_->size()) {
+    return cur_task_id_++;
+  }
+  return -1;
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/round_robin.h b/paddle/cinn/auto_schedule/task_scheduler/round_robin.h
new file mode 100644
index 0000000000000..55429fce92f1f
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/round_robin.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// Schedules tasks with the round_robin strategy, that is,
+// picking tasks to tune one at a time, in order.
+class RoundRobin : public TaskScheduler {
+ public:
+  RoundRobin(const std::vector<TuneTask>& tasks, const Config& config) : TaskScheduler(tasks, config) {}
+
+  const char* Name() const override { return "round_robin"; }
+
+  int NextTaskId() override;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc
new file mode 100644
index 0000000000000..0c6f99ad73c6e
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
+
+#include <glog/logging.h>
+
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/task_scheduler/efficiency_priority.h"
+#include "cinn/auto_schedule/task_scheduler/round_robin.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+std::unique_ptr<TaskScheduler> TaskScheduler::Make(const std::vector<TuneTask>& tasks,
+                                                   const Config& config,
+                                                   const std::string& strategy) {
+  CHECK_GT(tasks.size(), 0) << "Empty task list";
+  if (strategy == "round_robin") {
+    return std::make_unique<RoundRobin>(tasks, config);
+  } else if (strategy == "efficiency_priority") {
+    return std::make_unique<EfficiencyPriority>(tasks, config);
+  }
+
+  LOG(FATAL) << "Unimplemented strategy:" << strategy;
+  return nullptr;
+}
+
+TaskScheduler::TaskScheduler(const std::vector<TuneTask>& tasks, const Config& config)
+    : tasks_(&tasks), config_(config), cur_task_id_(0) {}
+
+void TaskScheduler::Reset() { cur_task_id_ = 0; }
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h
new file mode 100644
index 0000000000000..cd8776bd97620
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/auto_schedule/task/task_optimizer.h"
+#include "cinn/auto_schedule/task/tune_task.h"
+#include "cinn/auto_schedule/tuning.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// Class for scheduling tasks to perform auto-tuning
+class TaskScheduler {
+ public:
+  // All configs for the different schedule strategies
+  // are defined here together.
+  struct Config {
+    // The minimum threshold of the gain ratio, used by EfficiencyPriority
+    float minimum_gain_threshold = 0.0;
+  };
+
+  // Create a TaskScheduler with the given strategy name
+  // and the necessary construction parameters.
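+  //
+  // A minimal usage sketch (the strategy defaults to "round_robin"):
+  //   TaskScheduler::Config config;
+  //   auto scheduler = TaskScheduler::Make(tasks, config);
+  //   for (int id = scheduler->NextTaskId(); id != -1; id = scheduler->NextTaskId()) {
+  //     // tune tasks[id] ...
+  //   }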
+  static std::unique_ptr<TaskScheduler> Make(const std::vector<TuneTask>& tasks,
+                                             const Config& config,
+                                             const std::string& strategy = "round_robin");
+
+  // Reset associated states to schedule from the beginning
+  void Reset();
+
+  // Return the name of the schedule strategy
+  virtual const char* Name() const = 0;
+
+  // Select the next task to tune; returns -1 when no task is left
+  virtual int NextTaskId() = 0;
+
+ protected:
+  // A TaskScheduler object should be created with the static function Make
+  TaskScheduler(const std::vector<TuneTask>& tasks, const Config& config);
+
+  // The config for the scheduling strategy
+  Config config_;
+  // The current task id to be estimated
+  int cur_task_id_;
+  // Points to all tasks, not owned
+  const std::vector<TuneTask>* tasks_;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc
new file mode 100644
index 0000000000000..a05b8dab3fd28
--- /dev/null
+++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler_test.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/auto_schedule/task_scheduler/task_scheduler.h"
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+#include "cinn/auto_schedule/task_scheduler/efficiency_priority.h"
+#include "cinn/auto_schedule/task_scheduler/round_robin.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+TEST(TaskScheduler, Make) {
+  std::vector<TuneTask> tasks(3);
+  TaskScheduler::Config config;
+
+  auto round_robin = TaskScheduler::Make(tasks, config);
+  ASSERT_STREQ(round_robin->Name(), "round_robin");
+  auto efficiency_priority = TaskScheduler::Make(tasks, config, "efficiency_priority");
+  ASSERT_STREQ(efficiency_priority->Name(), "efficiency_priority");
+}
+
+TEST(RoundRobinScheduler, NextTaskId) {
+  std::vector<TuneTask> tasks(3);
+  TaskScheduler::Config config;
+  auto round_robin = TaskScheduler::Make(tasks, config);
+  ASSERT_EQ(0, round_robin->NextTaskId());
+  ASSERT_EQ(1, round_robin->NextTaskId());
+  round_robin->Reset();
+  ASSERT_EQ(0, round_robin->NextTaskId());
+}
+
+TEST(EfficiencyPriorityScheduler, NextTaskId) {
+  std::vector<TuneTask> tasks(3);
+  TaskScheduler::Config config;
+  config.minimum_gain_threshold = -1.0;
+  auto efficiency_priority = TaskScheduler::Make(tasks, config, "efficiency_priority");
+  ASSERT_EQ(-1, efficiency_priority->NextTaskId());
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/tests/CMakeLists.txt b/paddle/cinn/auto_schedule/tests/CMakeLists.txt
new file mode 100644
index 0000000000000..407400b1f241b
--- /dev/null
+++ b/paddle/cinn/auto_schedule/tests/CMakeLists.txt
@@ -0,0 +1,5 @@
+if (WITH_CUDA AND (NOT WITH_CUDNN))
+  cc_test(test_performance_comparison
+          ARGS "--resnet50_model_dir=${THIRD_PARTY_PATH}/ResNet50"
+          SRCS performance_comparison_test.cc DEPS cinncore test_program_builder)
+endif()
diff --git a/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc 
b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc
new file mode 100644
index 0000000000000..35a1e58063605
--- /dev/null
+++ b/paddle/cinn/auto_schedule/tests/performance_comparison_test.cc
@@ -0,0 +1,310 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+
+#include <bitset>
+#include <memory>
+
+#include "cinn/auto_schedule/auto_tuner.h"
+#include "cinn/common/target.h"
+#include "cinn/frontend/net_builder.h"
+#include "cinn/frontend/optimize.h"
+#include "cinn/frontend/paddle_model_convertor.h"
+#include "cinn/frontend/syntax.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/pass.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/runtime/flags.h"
+#include "cinn/utils/data_util.h"
+#include "tests/program_builder.h"
+
+/* This test is a tool to evaluate and compare the performance of three schedule modes (no schedule, manual schedule
+ * and auto-schedule). You can specify which modes to evaluate through `FLAGS_evaluate_knobs` and which operator or
+ * model to test through `--gtest_filter=PerformanceTester.xx`; for example, `FLAGS_evaluate_knobs=4
+ * --gtest_filter=PerformanceTester.Matmul` evaluates auto-schedule on the Matmul operator. Refer to the explanation
+ * of the following flags and parameters for more detail.
+ */
+
+DEFINE_string(resnet50_model_dir, "./ResNet50", "the path to the paddle resnet50 model.");
+// Flags that control which schedule tests will be run.
+// Bit with index 0 controls the no-schedule test, so options = 1 = "001" runs the no-schedule test.
+// Bit with index 1 controls the manual-schedule test, so options = 2 = "010" runs the manual-schedule test.
+// Bit with index 2 controls the auto-schedule test, so options = 4 = "100" runs the auto-schedule test.
+// The default value is -1, which means the flag is disabled and does not override the options.
+DEFINE_int32(evaluate_knobs, -1, "the options to control which schedule tests will be run.");
+DECLARE_int32(cinn_parallel_compile_size);
+
+namespace cinn {
+namespace auto_schedule {
+
+using ::cinn::hlir::framework::BuildScope;
+using ::cinn::hlir::framework::Graph;
+using ::cinn::hlir::framework::GraphCompiler;
+using ::cinn::hlir::framework::Instruction;
+using ::cinn::hlir::framework::Scope;
+
+class PerformanceTester : public ::testing::Test {
+ public:
+  struct Options {
+    // number of times the compiled runtime program is executed repeatedly
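+    // (e.g. repeat_times = 2 makes ExecuteTest below run every built program twice)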
+ int repeat_times = 2; + // the num_tuning_rounds for auto tuning + int num_tuning_rounds = 2; + // knobs to control which schedules will be measured, refer to FLAGS_evaluate_knobs explanation + std::bitset<3> evaluate_knobs = 0UL; + }; + + void SetUp() override { FLAGS_cinn_parallel_compile_size = 0; } + + void Evaluate(const frontend::Program& program) { + if (FLAGS_evaluate_knobs >= 0) { + options_.evaluate_knobs = FLAGS_evaluate_knobs; + } + VLOG(3) << "evaluate_knobs = " << options_.evaluate_knobs; + + auto worker_fn = [this, &program]( + const std::string& schedule_name, BuildRuntimeProgramFn build_fn, bool execute = true) { + Context::Global().ResetNameId(); + VLOG(3) << "Initialize graph."; + auto graph = std::make_shared(program, target_); + VLOG(3) << "Apply graph pass."; + hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); + VLOG(3) << "Build " << schedule_name << " program."; + auto scope = BuildScope(target_, graph); + auto graph_compiler = std::make_unique(target_, scope, graph); + auto runtime_program = (this->*build_fn)(graph.get(), graph_compiler.get()); + if (execute) { + VLOG(3) << "Execute " << schedule_name << " program."; + runtime_program->ExecuteTest(options_.repeat_times); + } + }; + + // if no one is set, build no/manual schedule cases to ensure their build functions are valid + if (options_.evaluate_knobs.none()) { + worker_fn("no schedule", &PerformanceTester::BuildNoScheduleProgram, /* execute */ false); + worker_fn("manual schedule", &PerformanceTester::BuildManualScheduleProgram, /* execute */ false); + } else { + if (options_.evaluate_knobs.test(0)) { + worker_fn("no schedule", &PerformanceTester::BuildNoScheduleProgram); + } + if (options_.evaluate_knobs.test(1)) { + worker_fn("manual schedule", &PerformanceTester::BuildManualScheduleProgram); + } + if (options_.evaluate_knobs.test(2)) { + worker_fn("auto schedule", &PerformanceTester::BuildAutoScheduleProgram); + } + } + } + + protected: + using BuildRuntimeProgramFn = std::unique_ptr (PerformanceTester::*)(Graph*, + GraphCompiler*); + + std::unique_ptr BuildNoScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { + const auto& dtype_dict = graph->GetAttrs>("inferdtype"); + const auto& shape_dict = graph->GetAttrs>("infershape"); + + std::shared_ptr op_lowerer = + std::make_unique(dtype_dict, shape_dict, target_); + + GraphCompiler::CompileOptions compile_options; + compile_options.with_instantiate_variables = true; + + if (graph->fusion_groups.empty()) { + hlir::framework::ApplyPasses(graph, {"BuildNonFusedGroupsPass"}); + } + compile_options.groups = graph->fusion_groups; + + for (auto group : graph->fusion_groups) { + compile_options.lowered_funcs.push_back(op_lowerer->LowerWithoutSchedule(group)); + } + + VLOG(3) << "===========================No Schedule LoweredFunc Begin==========================="; + for (const auto& funcvec : compile_options.lowered_funcs) { + for (const auto& func : funcvec) { + VLOG(3) << func; + } + } + VLOG(3) << "===========================No Schedule LoweredFunc End============================="; + + return graph_compiler->Build(compile_options).runtime_program; + } + + std::unique_ptr BuildManualScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { + return graph_compiler->Build(); + } + + std::unique_ptr BuildAutoScheduleProgram(Graph* graph, GraphCompiler* graph_compiler) { + auto tuner = std::make_unique(target_, graph); + + AutoTuner::Config tuning_config; + TuningOptions tuning_options; + tuning_options.num_tuning_rounds = 
options_.num_tuning_rounds; + tuning_options.num_measure_trials = 2; + tuning_options.num_samples_per_iteration = 2; + + tuner->Initialize(tuning_config, graph_compiler); + TuningResult tuning_result = tuner->Tune(tuning_options); + + GraphCompiler::CompileOptions compile_options; + compile_options.with_instantiate_variables = true; + compile_options.Apply(tuning_result); + + VLOG(3) << "===========================Auto Schedule LoweredFunc Begin==========================="; + for (const auto& funcvec : compile_options.lowered_funcs) { + for (const auto& func : funcvec) { + VLOG(3) << func; + } + } + VLOG(3) << "===========================Auto Schedule LoweredFunc End============================="; + + return graph_compiler->Build(compile_options).runtime_program; + } + +#ifdef CINN_WITH_CUDA + Target target_ = common::DefaultNVGPUTarget(); +#else + Target target_ = common::DefaultHostTarget(); +#endif + Options options_; +}; + +constexpr int batch_size = 2; + +TEST_F(PerformanceTester, Mul) { Evaluate(tests::OpBuilder("mul").Build({{"X", {32, 16}}, {"Y", {16, 32}}})); } + +TEST_F(PerformanceTester, Add) { + Evaluate(tests::OpBuilder("elementwise_add").Build({{"X", {1, 56, 56, 256}}, {"Y", {1, 56, 56, 256}}})); +} + +TEST_F(PerformanceTester, Matmul) { + Evaluate(tests::OpBuilder("matmul").Build({{"X", {batch_size, 2048}}, {"Y", {2048, 1000}}})); +} + +TEST_F(PerformanceTester, Relu) { Evaluate(tests::OpBuilder("relu").Build({{"X", {batch_size, 64, 56, 56}}})); } + +TEST_F(PerformanceTester, Conv2d) { + std::vector strides{2, 2}; + std::vector paddings{3, 3}; + std::vector dilations{1, 1}; + int groups = 1; + std::string conv_type = "forward"; + std::string data_format = "NCHW"; + std::string padding_algorithm = "EXPLICIT"; + + Evaluate(tests::OpBuilder("conv2d").Build({{"X", {batch_size, 3, 224, 224}}, {"W", {64, 3, 7, 7}}}, + {{"stride", strides}, + {"padding", paddings}, + {"dilation", dilations}, + {"groups", groups}, + {"conv_type", conv_type}, + {"data_format", data_format}, + {"padding_algorithm", padding_algorithm}})); +} + +TEST_F(PerformanceTester, Pool2d) { + std::vector input_shape{batch_size, 64, 112, 112}; + std::string pooling_type = "max"; + std::vector ksize{3, 3}; + std::vector strides{2, 2}; + std::vector paddings{1, 1, 1, 1}; + bool ceil_mode = false; + bool exclusive = true; + bool global_pooling = false; + std::string data_format = "NCHW"; + bool adaptive = false; + std::string padding_algorithm = "EXPLICIT"; + + Evaluate(tests::OpBuilder("pool2d").Build({{"X", {batch_size, 64, 112, 112}}}, + {{"pool_type", pooling_type}, + {"kernel_size", ksize}, + {"stride_size", strides}, + {"padding_size", paddings}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}, + {"global_pooling", global_pooling}, + {"data_format", data_format}, + {"adaptive", adaptive}, + {"padding_algorithm", padding_algorithm}})); +} + +TEST_F(PerformanceTester, BatchNorm) { + std::vector input_shape{batch_size, 64, 112, 112}; + std::vector scale_shape{64}; + std::vector bias_shape{64}; + std::vector mean_shape{64}; + std::vector variance_shape{64}; + float epsilon = 1e-5f; + float momentum = 0.9f; + const std::string& data_layout = "NCHW"; + + Evaluate( + tests::OpBuilder("batch_norm") + .Build( + {{"X", {batch_size, 64, 112, 112}}, {"scale", {64}}, {"bias", {64}}, {"mean", {64}}, {"variance", {64}}}, + {{"epsilon", epsilon}, {"momentum", momentum}, {"data_layout", data_layout}})); +} + +TEST_F(PerformanceTester, Reshape) { + std::vector output_shape{batch_size, 2048}; + + 
Evaluate(tests::OpBuilder("reshape").Build({{"X", {batch_size, 2048, 1, 1}}}, {{"shape", output_shape}}));
+}
+
+TEST_F(PerformanceTester, Softmax) {
+  std::vector<int> axes = {-1};
+  std::string mode = "fast";
+  std::string data_format = "AnyLayout";
+
+  Evaluate(tests::OpBuilder("softmax").Build({{"X", {batch_size, 1000}}},
+                                             {{"axes", axes}, {"mode", mode}, {"data_format", data_format}}));
+}
+
+TEST_F(PerformanceTester, Scale) {
+  float scale = 1.0f;
+  float bias = 0.0f;
+  bool bias_after_scale = true;
+
+  Evaluate(tests::OpBuilder("scale").Build(
+      {{"X", {batch_size, 1000}}}, {{"scale", scale}, {"bias", bias}, {"bias_after_scale", bias_after_scale}}));
+}
+
+TEST_F(PerformanceTester, LookupTable) {
+  int64_t padding_idx = -1;
+
+  Evaluate(
+      tests::OpBuilder("lookup_table")
+          .Build({{"table", {50001, 768}}, {"ids", {10, 128, 1}, common::Int(64)}}, {{"padding_idx", padding_idx}}));
+}
+
+TEST_F(PerformanceTester, Gather) {
+  int axis = 3;
+
+  Evaluate(tests::OpBuilder("gather").Build(
+      {{"operand", {10, 12, 128, 512}}, {"index", {1, 1, 1, 128}, common::Int(32)}}, {{"axis", axis}}));
+}
+
+// paddle model test
+TEST_F(PerformanceTester, ResNet50) {
+  CHECK_NE(FLAGS_resnet50_model_dir, "");
+  std::unordered_map<std::string, std::vector<int>> feeds = {{"inputs", {batch_size, 3, 224, 224}}};
+  Evaluate(cinn::frontend::PaddleModelConvertor(common::DefaultNVGPUTarget())
+               .LoadModel(FLAGS_resnet50_model_dir, true, feeds));
+}
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/auto_schedule/tuning.h b/paddle/cinn/auto_schedule/tuning.h
new file mode 100644
index 0000000000000..0b2bfe66d1273
--- /dev/null
+++ b/paddle/cinn/auto_schedule/tuning.h
@@ -0,0 +1,91 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/ir/lowered_func.h"
+
+namespace cinn {
+namespace auto_schedule {
+
+// alias a LoweredFunc array as FunctionGroup
+using FunctionGroup = std::vector<ir::LoweredFunc>;
+// alias a shared pointer to Graph::Group as SubGraphPtr
+using SubGraphPtr = std::shared_ptr<hlir::framework::Graph::Group>;
+
+// Options for the tuning process
+struct TuningOptions {
+  // The number of tuning rounds; each round tunes several tasks and
+  // each task involves TuningOptions.num_measure_trials measurements.
+  int num_tuning_rounds = 1;
+
+  // The number of measurement trials in a task. If it is 0,
+  // the tuner returns the best candidate schedule config
+  // without any hardware measurement.
+  int num_measure_trials = 10;
+
+  // Every round the TaskScheduler chooses some TuneTask(s) to optimize and
+  // runs several iterations of the search algorithm to generate samples.
+  // Each iteration produces num_samples_per_iteration samples.
+  //
+  // 1. If TuningOptions.num_measure_trials is 0, auto-tuning involves no
+  //    hardware measurement and predicts performance with the cost model.
+  //
+  // 2. num_measure_trials % num_samples_per_iteration must equal 0.
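+  //    (e.g. num_measure_trials = 10 with num_samples_per_iteration = 5
+  //    yields exactly two search iterations per round.)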
+  //    In each round, auto-tuning runs search iterations until the number of
+  //    iterations * num_samples_per_iteration equals num_measure_trials.
+  int num_samples_per_iteration = 10;
+
+  //////////////////////////////////////
+  // Evolutionary Search Related Options
+  //////////////////////////////////////
+
+  // The number of picks from the stored database in each iteration.
+  // These are the best performing records from previous generations.
+  //
+  // Note that returning the full topk is not guaranteed when the database
+  // doesn't hold enough data; evolutionary search takes as many records as
+  // are available without raising errors or warnings.
+  int evolution_pick_database_topk = 8;
+
+  // The number of initial populations at each generation. It contains
+  // the picks from the database plus randomly generated samples.
+  int evolution_init_population_num = 10;
+
+  // The number of samples generated by cross-over
+  int evolution_cross_over_num = 0;
+
+  // The fraction of random samples in num_samples_per_iteration.
+  // So num_samples_per_iteration consists of (1 - eps_greedy) best
+  // samples from evolutionary search and eps_greedy random samples.
+  //
+  // It explores the cases evolutionary search won't predict precisely.
+  float evolution_eps_greedy = 0.1f;
+};
+
+// Result of the tuning process
+struct TuningResult {
+  // Result of graph tuning
+  std::vector<SubGraphPtr> subgraphs;
+  // Result of schedule tuning
+  std::vector<FunctionGroup> function_groups;
+};
+
+}  // namespace auto_schedule
+}  // namespace cinn
diff --git a/paddle/cinn/backends/CMakeLists.txt b/paddle/cinn/backends/CMakeLists.txt
new file mode 100755
index 0000000000000..3949bc4e7313d
--- /dev/null
+++ b/paddle/cinn/backends/CMakeLists.txt
@@ -0,0 +1,67 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS
+  outputs.cc
+  codegen_c.cc
+  codegen_c_x86.cc
+  codegen_cuda_host.cc
+  extern_func_emitter.cc
+  extern_func_emitter_builtin.cc
+  function_prototype.cc
+  extern_func_protos.cc
+  extern_func_jit_register.cc
+  modular.cc
+  compiler.cc
+)
+
+if (WITH_CUDA)
+  add_subdirectory(nvrtc)
+  list(APPEND srcs cuda_util.cc codegen_cuda_dev.cc codegen_cuda_util.cc)
+endif()
+
+if (WITH_OPENMP)
+  cc_library(__x86_source_fake_lib SRCS _x86_builtin_source.cc)
+endif()
+add_subdirectory(llvm)
+
+if (WITH_CUDA)
+  nv_test(test_raw_cuda_code SRCS raw_cuda_code_test.cu DEPS cinncore)
+endif()
+
+cc_test(test_codegen_c SRCS codegen_c_test.cc DEPS cinncore ARGS ${global_test_args})
+cc_test(test_codegen_c_x86 SRCS codegen_c_x86_test.cc DEPS cinncore ARGS ${global_test_args})
+cc_test(test_generated1 SRCS generated_module1.cc DEPS cinn_runtime)
+add_run_test_dependency(test_generated1 test_codegen_c)
+cc_test(test_ir_schedule SRCS ir_schedule_test.cc DEPS cinncore)
+include_directories(${CMAKE_SOURCE_DIR}/cinn/runtime)
+if (TARGET test_generated1)
+  add_dependencies(test_generated1 test_codegen_c)
+endif()
+
+if (WITH_CUDA)
+  nv_test(test_codegen_cuda_generate SRCS codegen_cuda_generate_test.cc DEPS cinncore)
+  nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore)
+
+  if (WITH_TESTING)
+    nv_test(generated1_cuda SRCS generated1.cu DEPS cinncore)
+    add_run_test_dependency(generated1_cuda test_codegen_cuda_generate)
+  endif()
+
+  nv_test(test_compiler SRCS compiler_test.cc DEPS cinncore)
+else()
+  cc_test(test_compiler SRCS compiler_test.cc DEPS cinncore)
+endif()
+
+foreach(cpp ${srcs})
+  set(cinnapi_src
+      "${cinnapi_src};cinn/backends/${cpp}"
+      CACHE INTERNAL "")
+endforeach()
+
+file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
+
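+# Each collected header is appended to the global `core_includes` cache
+# variable below; packaging rules elsewhere consume this list, e.g.
+# (illustrative only):
+#   install(FILES ${core_includes} DESTINATION include/cinn/backends)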
+foreach(header ${includes}) + set(core_includes "${core_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/cinn/backends/_x86_builtin_source.cc b/paddle/cinn/backends/_x86_builtin_source.cc new file mode 100644 index 0000000000000..f29b3cc79ca81 --- /dev/null +++ b/paddle/cinn/backends/_x86_builtin_source.cc @@ -0,0 +1,378 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +/// Predefined utilities in CINN BEGIN( +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include + +#include "cinn/runtime/cpu/thread_backend.h" + +#ifndef _CINN_X86_BUILTIN_SOURCE_ +#define _CINN_X86_BUILTIN_SOURCE_ +//! Vector in stack, this can only used in generated .cc file. +template +struct StackVec { + typedef T value_type; + typedef StackVec self_type; + + self_type& operator=(const StackVec& src) { + if (this != &src) { + memcpy(data_, src.data_, num_bytes()); + } + return *this; + } + + StackVec() { memset(data_, 0, num_bytes()); } + + explicit StackVec(const T* externl) : external_data_(externl) {} + + static self_type Broadcast(const value_type& v) { + self_type res; + for (size_t i = 0; i < Num; i++) res.data_[i] = v; + return res; + } + + static self_type Ramp(const value_type& base, const value_type& stride) { + self_type res; + for (size_t i = 0; i < Num; i++) { + res.data_[i] = base + stride * i; + } + } + + static self_type Load(const void* base, int32_t offset) { + self_type res; + memcpy(&res.data_[0], (const value_type*)base + offset, num_bytes()); + } + + static self_type Load(const void* base, const StackVec& offset) { + self_type res; + for (size_t i = 0; i < Num; i++) { + res.data_[i] = ((const value_type*)base)[offset[i]]; + } + } + + void Store(void* base, int32_t offset) const { + mempcpy((value_type*)base + offset, &data_[0], num_bytes()); // NOLINT + } + + inline value_type& operator[](size_t i) { return data_[i]; } + inline value_type operator[](size_t i) const { return data_[i]; } + + // binary operator between two vectors + // @{ +#define __(op__) \ + friend self_type operator op__(const self_type& a, const self_type& b) { \ + self_type res; \ + for (size_t i = 0; i < Num; i++) { \ + res.data_[i] = a[i] op__ b[i]; \ + } \ + return res; \ + } + __(+) + __(-) + __(*) + __(/) + __(%) + // @} +#undef __ + + // binary operator between a vector and a scalar + // @{ +#define __(op__) \ + friend self_type operator op__(const self_type& a, const value_type& b) { \ + self_type res; \ + for (size_t i = 0; i < Num; i++) { \ + res.data_[i] = a[i] op__ b; \ + } \ + return res; \ + } + __(+) + __(-) + __(*) + __(/) + __(%) +#undef __ + // @} + + static constexpr size_t num_bytes() { return sizeof(data_); } + + private: + T data_[Num]; + T* external_data_{nullptr}; +}; + +/** 
+ * The vector with external data. + */ +template +struct ExternalVec { + typedef T value_type; + typedef ExternalVec self_type; + + explicit ExternalVec(T* data) : data_(data) {} + + self_type& operator=(const self_type& src) { + if (data_ != src.data_) { + memcpy(data_, src.data_, num_bytes()); + } + return *this; + } + + static self_type Load(const void* base, int32_t offset) { + self_type res((T*)base + offset); // NOLINT + return res; + } + + static constexpr size_t num_bytes() { return sizeof(value_type) * Num; } + + private: + T* data_{nullptr}; +}; + +// AVX256 load +//@{ +inline __m256 cinn_avx256_load(const float* dst) { return _mm256_load_ps(dst); } +inline __m256d cinn_avx256_load(const double* dst) { return _mm256_load_pd(dst); } +//@} +// AVX512 load +//@{ +inline __m512 cinn_avx512_load(const float* dst) { return _mm512_load_ps(dst); } +inline __m512d cinn_avx512_load(const double* dst) { return _mm512_load_pd(dst); } +//@} + +// FP32x8 * FP32x8 +// @{ +inline void cinn_avx256_add(float* dst, float* a, float* b) { + _mm256_store_ps(dst, _mm256_add_ps(_mm256_load_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_sub(float* dst, float* a, float* b) { + _mm256_store_ps(dst, _mm256_sub_ps(_mm256_load_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_mul(float* dst, float* a, float* b) { + _mm256_store_ps(dst, _mm256_mul_ps(_mm256_load_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_div(float* dst, float* a, float* b) { + _mm256_store_ps(dst, _mm256_div_ps(_mm256_load_ps(a), _mm256_load_ps(b))); +} +// @} + +// FP32x4 * float +// @{ +inline void cinn_avx256_add(float* dst, float* a, float b) { + _mm256_store_ps(dst, _mm256_add_ps(_mm256_load_ps(a), _mm256_set1_ps(b))); +} +inline void cinn_avx256_sub(float* dst, float* a, float b) { + _mm256_store_ps(dst, _mm256_sub_ps(_mm256_load_ps(a), _mm256_set1_ps(b))); +} +inline void cinn_avx256_mul(float* dst, float* a, float b) { + _mm256_store_ps(dst, _mm256_mul_ps(_mm256_load_ps(a), _mm256_set1_ps(b))); +} +inline void cinn_avx256_div(float* dst, float* a, float b) { + _mm256_store_ps(dst, _mm256_div_ps(_mm256_load_ps(a), _mm256_set1_ps(b))); +} +// @} + +// float * FP32x4 +// @{ +inline void cinn_avx256_add(float* dst, float a, float* b) { + _mm256_store_ps(dst, _mm256_add_ps(_mm256_set1_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_sub(float* dst, float a, float* b) { + _mm256_store_ps(dst, _mm256_sub_ps(_mm256_set1_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_mul(float* dst, float a, float* b) { + _mm256_store_ps(dst, _mm256_mul_ps(_mm256_set1_ps(a), _mm256_load_ps(b))); +} +inline void cinn_avx256_div(float* dst, float a, float* b) { + _mm256_store_ps(dst, _mm256_div_ps(_mm256_set1_ps(a), _mm256_load_ps(b))); +} +// @} + +// 4 x float64 +// @{ +inline void cinn_avx256_add(double* dst, double* a, double* b) { + _mm256_store_pd(dst, _mm256_add_pd(_mm256_load_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_sub(double* dst, double* a, double* b) { + _mm256_store_pd(dst, _mm256_sub_pd(_mm256_load_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_mul(double* dst, double* a, double* b) { + _mm256_store_pd(dst, _mm256_mul_pd(_mm256_load_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_div(double* dst, double* a, double* b) { + _mm256_store_pd(dst, _mm256_div_pd(_mm256_load_pd(a), _mm256_load_pd(b))); +} +// @} + +// FP32x4 * FP64 +// @{ +inline void cinn_avx256_add(double* dst, double* a, double b) { + _mm256_store_pd(dst, _mm256_add_pd(_mm256_load_pd(a), _mm256_set1_pd(b))); +} 
+inline void cinn_avx256_sub(double* dst, double* a, double b) { + _mm256_store_pd(dst, _mm256_sub_pd(_mm256_load_pd(a), _mm256_set1_pd(b))); +} +inline void cinn_avx256_mul(double* dst, double* a, double b) { + _mm256_store_pd(dst, _mm256_mul_pd(_mm256_load_pd(a), _mm256_set1_pd(b))); +} +inline void cinn_avx256_div(double* dst, double* a, double b) { + _mm256_store_pd(dst, _mm256_div_pd(_mm256_load_pd(a), _mm256_set1_pd(b))); +} +// @} + +// float * FP32x4 +// @{ +inline void cinn_avx256_add(double* dst, double a, double* b) { + _mm256_store_pd(dst, _mm256_add_pd(_mm256_set1_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_sub(double* dst, double a, double* b) { + _mm256_store_pd(dst, _mm256_sub_pd(_mm256_set1_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_mul(double* dst, double a, double* b) { + _mm256_store_pd(dst, _mm256_mul_pd(_mm256_set1_pd(a), _mm256_load_pd(b))); +} +inline void cinn_avx256_div(double* dst, double a, double* b) { + _mm256_store_pd(dst, _mm256_div_pd(_mm256_set1_pd(a), _mm256_load_pd(b))); +} +// @} + +//! 32 x float32 operations. +// @{ +inline void cinn_avx512_add(float* dst, float* a, float* b) { + _mm512_store_ps(dst, _mm512_add_ps(_mm512_load_ps(a), _mm512_load_ps(b))); +} +inline void cinn_avx512_sub(float* dst, float* a, float* b) { + _mm512_store_ps(dst, _mm512_sub_ps(_mm512_load_ps(a), _mm512_load_ps(b))); +} +inline void cinn_avx512_mul(float* dst, float* a, float* b) { + _mm512_store_ps(dst, _mm512_mul_ps(_mm512_load_ps(a), _mm512_load_ps(b))); +} +inline void cinn_avx512_div(float* dst, float* a, float* b) { + _mm512_store_ps(dst, _mm512_div_ps(_mm512_load_ps(a), _mm512_load_ps(b))); +} +// @} + +// FP32x4 * FP64 +// @{ +inline void cinn_avx512_add(float* dst, float* a, float b) { + _mm512_store_pd(dst, _mm512_add_pd(_mm512_load_pd(a), _mm512_set1_pd(b))); +} +inline void cinn_avx512_sub(float* dst, float* a, float b) { + _mm512_store_pd(dst, _mm512_sub_pd(_mm512_load_pd(a), _mm512_set1_pd(b))); +} +inline void cinn_avx512_mul(float* dst, float* a, float b) { + _mm512_store_pd(dst, _mm512_mul_pd(_mm512_load_pd(a), _mm512_set1_pd(b))); +} +inline void cinn_avx512_div(float* dst, float* a, float b) { + _mm512_store_pd(dst, _mm512_div_pd(_mm512_load_pd(a), _mm512_set1_pd(b))); +} +// @} + +// float * FP32x4 +// @{ +inline void cinn_avx512_add(float* dst, float a, float* b) { + _mm512_store_pd(dst, _mm512_add_pd(_mm512_set1_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_sub(float* dst, float a, float* b) { + _mm512_store_pd(dst, _mm512_sub_pd(_mm512_set1_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_mul(float* dst, float a, float* b) { + _mm512_store_pd(dst, _mm512_mul_pd(_mm512_set1_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_div(float* dst, float a, float* b) { + _mm512_store_pd(dst, _mm512_div_pd(_mm512_set1_pd(a), _mm512_load_pd(b))); +} +// @} + +//! 16 x float32 operations. 
+// @{ +inline void cinn_avx512_add(double* dst, double* a, double* b) { + _mm512_store_pd(dst, _mm512_add_pd(_mm512_load_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_sub(double* dst, double* a, double* b) { + _mm512_store_pd(dst, _mm512_sub_pd(_mm512_load_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_mul(double* dst, double* a, double* b) { + _mm512_store_pd(dst, _mm512_mul_pd(_mm512_load_pd(a), _mm512_load_pd(b))); +} +inline void cinn_avx512_div(double* dst, double* a, double* b) { + _mm512_store_pd(dst, _mm512_div_pd(_mm512_load_pd(a), _mm512_load_pd(b))); +} +// @} + +inline __m512 cinn_avx512_add(const __m512& a, const __m512& b); + +inline __m256 cinn_avx256_add_float(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); } +inline __m256d cinn_avx256_add_double(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); } +inline __m512 cinn_avx512_add_float(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); } +inline __m512d cinn_avx512_add_double(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); } + +//! set1 +// @{ +inline __m256 cinn_avx256_set1(float value) { return _mm256_set1_ps(value); } +inline __m256d cinn_avx256_set1(double value) { return _mm256_set1_pd(value); } +inline __m512 cinn_avx512_set1(float value) { return _mm512_set1_ps(value); } +inline __m512d cinn_avx512_set1(double value) { return _mm512_set1_pd(value); } +// @} + +//! store +// @{ +inline void cinn_avx512_store(float* dst, const __m512& x) { _mm512_store_ps(dst, x); } +inline void cinn_avx512_store(double* dst, const __m512d& x) { _mm512_store_pd(dst, x); } +inline void cinn_avx256_store(float* dst, const __m256& x) { _mm256_store_ps(dst, x); } +inline void cinn_avx256_store(double* dst, const __m256d& x) { _mm256_store_pd(dst, x); } +// @} + +//! add +// @{ +inline __m256 cinn_avx256_add(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); } +inline __m256d cinn_avx256_add(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); } +inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); } +inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); } +// @} + +//! mul +// @{ +inline __m256 cinn_avx256_mul(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); } +inline __m256d cinn_avx256_mul(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); } +inline __m512 cinn_avx512_mul(const __m512& a, const __m512& b) { return _mm512_mul_ps(a, b); } +inline __m512d cinn_avx512_mul(const __m512d& a, const __m512d& b) { return _mm512_mul_pd(a, b); } +// @} + +//! 
fma +// @{ +inline __m128 cinn_avx128_fma(const __m128& a, const __m128& b, const __m128& c) { return _mm_fmadd_ps(a, b, c); } +inline __m128d cinn_avx128_fma(const __m128d& a, const __m128d& b, const __m128d& c) { return _mm_fmadd_pd(a, b, c); } +inline __m256 cinn_avx256_fma(const __m256& a, const __m256& b, const __m256& c) { return _mm256_fmadd_ps(a, b, c); } +inline __m256d cinn_avx256_fma(const __m256d& a, const __m256d& b, const __m256d& c) { + return _mm256_fmadd_pd(a, b, c); +} +inline __m512 cinn_avx512_fma(const __m512& a, const __m512& b, const __m512& c) { return _mm512_fmadd_ps(a, b, c); } +inline __m512d cinn_avx512_fma(const __m512d& a, const __m512d& b, const __m512d& c) { + return _mm512_fmadd_pd(a, b, c); +} +// @} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +/// )END Predefined utilities in CINN +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // _CINN_X86_BUILTIN_SOURCE_ diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc new file mode 100644 index 0000000000000..a5a26ecea027c --- /dev/null +++ b/paddle/cinn/backends/codegen_c.cc @@ -0,0 +1,868 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/codegen_c.h" + +#include +#include + +#include "cinn/backends/extern_func_emitter.h" +#include "cinn/backends/extern_func_emitter_builtin.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_verify.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/remove_nested_block.h" +#include "cinn/runtime/cpu/thread_backend.h" +#include "cinn/runtime/intrinsic.h" +#include "cinn/utils/string.h" + +//! Root of the builtin code. 
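+// (set via the --cinn_x86_builtin_code_root flag; see PrintBuiltinCodes()
+// below, which inlines the builtin sources when inline_builtin_codes_ is
+// enabled)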
+DECLARE_string(cinn_x86_builtin_code_root); + +namespace cinn { +namespace backends { +using namespace utils; // NOLINT +using cinn::common::float16; + +const char *kCKeywordRestrict = "__restrict__"; + +void CodeGenC::Compile(const ir::Module &module, const Outputs &outputs) { + ir::IrVerify(Expr(module)); + + if (!outputs.c_header_name.empty()) { + auto source = Compile(module, OutputKind::CHeader); + std::ofstream file(outputs.c_header_name); + CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name; + file << source; + file.close(); + LOG(WARNING) << "Output C header to file " << outputs.c_header_name; + } + + if (!outputs.c_source_name.empty()) { + auto source = Compile(module, OutputKind::CImpl); + std::ofstream file(outputs.c_source_name); + CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name; + file << source; + file.close(); + LOG(WARNING) << "Output C source to file " << outputs.c_source_name; + } +} + +CodeGenC::CodeGenC(Target target) : ir::IrPrinter(ss_) {} + +std::string CodeGenC::Compile(const ir::Module &module, OutputKind output_kind) { + if (output_kind == OutputKind::CHeader) { + GenerateHeaderFile(module); + } else if (output_kind == OutputKind::CImpl) { + PrintIncludes(); + + if (inline_builtin_codes_) PrintBuiltinCodes(); + + std::vector buffers; + for (auto &buffer : module->buffers) { + buffers.emplace_back(buffer.as_buffer_ref()); + } + + for (auto &func : module.functions()) { + Compile(func); + } + } else { + LOG(FATAL) << "Not supported OutputKind"; + } + return ss_.str(); +} +std::string CodeGenC::Compile(const ir::LoweredFunc &function) { + CHECK(function.defined()); + Print(function); + os() << "\n\n"; + return ss_.str(); +} + +std::string CodeGenC::GetTypeName(Type type) { + // common scalar type +#define GET_SCALAR_TYPE(pred_expr, scalar_name) \ + if (pred_expr) { \ + return scalar_name; \ + } + + GET_SCALAR_TYPE(type.is_void(), "void"); + GET_SCALAR_TYPE(type.is_bool(), "bool"); + + GET_SCALAR_TYPE(type.is_int(8), "int8_t"); + GET_SCALAR_TYPE(type.is_int(16), "int16_t"); + GET_SCALAR_TYPE(type.is_int(32), "int32_t"); + GET_SCALAR_TYPE(type.is_int(64), "int64_t"); + + GET_SCALAR_TYPE(type.is_uint(8), "uint8_t"); + GET_SCALAR_TYPE(type.is_uint(16), "uint16_t"); + GET_SCALAR_TYPE(type.is_uint(32), "uint32_t"); + GET_SCALAR_TYPE(type.is_uint(64), "uint64_t"); + + GET_SCALAR_TYPE(type.is_bfloat16(), "bfloat16"); + GET_SCALAR_TYPE(type.is_float16(), "float16"); + GET_SCALAR_TYPE(type.is_float(32), "float") + GET_SCALAR_TYPE(type.is_float(64), "double") +#undef GET_SCALAR_TYPE + + // customized_type + if (type.is_customized_type()) { + CHECK(!type.customized_type().empty()) << "customized_type can't be empty."; + auto customized_name = type.customized_type(); + // get name of a cuda built-in vector type, it is started with a 'CudaVectorType::' prefix + if (utils::Startswith(customized_name, common::customized_type::kcuda_builtin_vector_t)) { + customized_name.erase(0, strlen(common::customized_type::kcuda_builtin_vector_t)); + } + return customized_name; + } + + // other types are not implemented yet + CINN_NOT_IMPLEMENTED + return ""; +} + +std::string CodeGenC::GetTypeRepr(Type type) { + std::string str; + if (type.is_cpp_const()) { + str = "const "; + } + + str += GetTypeName(type); + if (type.is_cpp_handle()) { + str += "*"; + } else if (type.is_cpp_handle2()) { + str += "**"; + } + return str; +} +void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::UIntImm *op) { 
IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::StringImm *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Add *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Sub *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Mod *op) { + auto copied = op->b(); + optim::Simplify(&copied); + if (copied.is_constant()) { + int temp = (int)(copied.get_constant()); + if ((temp & (temp - 1)) == 0) { + os() << "("; + Print(op->a()); + os() << " & "; + os() << std::to_string(temp - 1); + os() << ")"; + return; + } + } + PrintBinaryOp("%", op); +} +void CodeGenC::Visit(const ir::EQ *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::NE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::LE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GT *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::GE *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::And *op) { PrintBinaryOp("&&", op); } +void CodeGenC::Visit(const ir::Or *op) { PrintBinaryOp("||", op); } +void CodeGenC::Visit(const ir::Min *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Max *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Minus *op) { IrPrinter::Visit(op); } +void CodeGenC::Visit(const ir::Not *op) { + os() << "(!"; + IrPrinter::Print(op->v()); + os() << ")"; +} +void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); } +void CodeGenC::Visit(const ir::For *op) { + Expr extent = op->extent; + Expr min = op->min; + int num_task = 1; + if (op->is_parallel()) { + os() << "int num_task = max_concurrency();\n"; + DoIndent(); + os() << "omp_set_num_threads(num_task);\n"; + DoIndent(); + os() << "auto flambda = [=](int task_id, int num_task) -> int {\n"; + IncIndent(); + DoIndent(); + os() << "int n_per_task = "; + Expr num_task_var = Var("num_task"); + Print((op->extent + num_task_var - 1) / num_task_var); + os() << ";\n"; + CHECK_EQ(min.as_int32(), 0); + auto task_id = Var("task_id"); + auto n_per_task = Var("n_per_task"); + min = task_id * n_per_task; + extent = (task_id + 1) * n_per_task; + DoIndent(); + } + os() << "for ("; + os() << GetTypeRepr(Int(32)); + os() << " " << op->loop_var->name; + os() << " = "; + Print(min); + os() << "; "; + os() << op->loop_var->name; + os() << " < "; + Print(op->extent); + if (op->is_parallel()) { + os() << " && "; + os() << op->loop_var->name; + os() << " < "; + Print(extent); + } + os() << "; "; + + os() << op->loop_var->name; + os() << " += 1"; + os() << ") "; + + Print(op->body); + if (op->is_parallel()) { + os() << "\n"; + DoIndent(); + os() << "return 0;\n"; + DecIndent(); + DoIndent(); + os() << "};\n"; + os() << "#pragma omp parallel num_threads(num_task)\n"; + DoIndent(); + os() << "{\n"; + IncIndent(); + DoIndent(); + os() << "int task_id = omp_get_thread_num();\n"; + DoIndent(); + os() << "flambda(task_id, num_task);\n"; + DecIndent(); + DoIndent(); + os() << "}"; + } +} +void CodeGenC::Visit(const ir::PolyFor *op) { + os() << "for ("; + os() << GetTypeRepr(Int(32)); + os() << " " << op->iterator->name; + os() << " = "; + Print(op->init); + os() << "; "; + Print(op->condition); + os() << "; "; + + os() << op->iterator->name; + os() << " += "; + Print(op->inc); + os() << ") "; + + 
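+ // With the header printed above and the body printed next, the emitted code
+ // reads like "for (int32_t i = 0; i < 100; i += 1) { ... }" (illustrative
+ // bounds; the actual init/condition/inc come from the PolyFor node).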
Print(op->body); +} +void CodeGenC::Visit(const ir::Select *op) { + os() << "("; + os() << "("; + Print(op->condition); + os() << ") ? "; + Print(op->true_value); + os() << " : "; + Print(op->false_value); + os() << ")"; +} +void CodeGenC::Visit(const ir::IfThenElse *op) { + os() << "if ("; + Print(op->condition); + os() << ") {\n"; + + if (!op->true_case.As()) IncIndent(); + DoIndent(); + Print(op->true_case); + if (!op->true_case.As()) os() << ";"; + os() << "\n"; + + if (!op->true_case.As()) DecIndent(); + + DoIndent(); + os() << "}"; + + if (op->false_case.defined()) { + os() << " else {\n"; + + if (!op->true_case.As()) IncIndent(); + DoIndent(); + Print(op->false_case); + if (!op->false_case.As()) os() << ";"; + os() << "\n"; + if (!op->true_case.As()) DecIndent(); + + DoIndent(); + os() << "}"; + } +} +void CodeGenC::Visit(const ir::Block *op) { + os() << "{\n"; + + IncIndent(); + + for (int i = 0; i < op->stmts.size() - 1; i++) { + DoIndent(); + Print(op->stmts[i]); + os() << ";\n"; + } + if (op->stmts.size() >= 1) { + DoIndent(); + Print(op->stmts.back()); + os() << ";"; + } + + DecIndent(); + os() << "\n"; + DoIndent(); + os() << "}"; +} +void CodeGenC::Visit(const ir::Call *op) { + if (op->name == runtime::intrinsic::buffer_malloc) { + PrintCall_buffer_malloc(op); + } else if (op->name == runtime::intrinsic::pod_values_to_array_repr) { + PrintCall_pod_values_to_array(op); + } else if (op->is_intrinsic_call()) { + os() << op->name << "("; + PrintCallArgs(op); + os() << ")"; + } else if (op->is_cinn_call()) { // call CINN LoweredFunc + os() << op->name << "("; + PrintCallArgs(op); + os() << ")"; + } else if (op->is_extern_call()) { + const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(ExternFuncID{backend_C, op->name.c_str()}); + if (!fn_name.empty()) { + ExternFunctionLLVMEmitter emitter(fn_name); + emitter.BindCodeGen(this); + emitter.Emit(op); + } else { + CHECK(!op->read_args.empty() || !op->write_args.empty()); + os() << op->name << "("; + PrintCallArgs(op); + os() << ")"; + } + } else { + CINN_NOT_IMPLEMENTED + } +} +void CodeGenC::PrintCallArgs(const ir::Call *op) { + if (!op->read_args.empty()) { + for (int i = 0; i < op->read_args.size() - 1; i++) { + Print(op->read_args[i]); + os() << ", "; + } + Print(op->read_args.back()); + } + if (!op->write_args.empty()) { + if (!op->read_args.empty()) os() << ", "; + + for (int i = 0; i < op->write_args.size() - 1; i++) { + Print(op->write_args[i]); + os() << ", "; + } + Print(op->write_args.back()); + } +} + +void CodeGenC::PrintCall_buffer_malloc(const ir::Call *op) { + CHECK_EQ(op->read_args.size(), 2UL); + os() << op->name << "("; + PrintCastExpr("void*", op->read_args[0]); + os() << ", "; + os() << op->read_args[1]; + os() << ")"; +} + +void CodeGenC::PrintCall_cinn_pod_value_to_(const ir::Call *op) { + CHECK_EQ(op->read_args.size(), 1UL); + os() << op->name << "("; + os() << "&("; + Print(op->read_args[0]); + os() << ")"; + os() << ")"; +} + +void CodeGenC::PrintCall_get_address(const ir::Call *op) { + CHECK_EQ(op->read_args.size(), 1UL); + CHECK(op->write_args.empty()); + auto *read_var = op->read_args.front().as_var(); + auto *read_buf = op->read_args.front().as_buffer(); + CHECK(read_var || read_buf) << "Only Var or Buffer can get address"; + + if (read_var) { + if (read_var->type().lanes() <= 1) os() << "&"; + os() << read_var->name; + } else if (read_buf) { + if (read_buf->type().lanes() <= 1) os() << "&"; + os() << read_buf->name; + } else { + CINN_NOT_IMPLEMENTED + } +} + +void 
CodeGenC::PrintCall_pod_values_to_array(const ir::Call *op) { + CHECK(!op->read_args.empty()); + CHECK_EQ(op->write_args.size(), 1UL); + auto output_var = op->write_args.front().as_var_ref(); + CHECK(output_var.defined()); + + std::vector arg_names; + for (auto &arg : op->read_args) { + auto arg_var = arg.as_var(); + CHECK(arg_var); + arg_names.push_back(arg_var->name); + } + + os() << "cinn_pod_value_t " << output_var->name << "[] = "; + os() << "{ "; + + os() << utils::Join(arg_names, ", "); + + os() << " }"; +} + +void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; } + +void CodeGenC::Visit(const ir::Load *op) { + Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); + if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address. + CHECK(op->type().is_vector()); + PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); + os() << "::" + << "Load("; + os() << op->tensor.As()->name; + os() << ","; + Print(dense_strided_ramp); + os() << ")"; + } else if (op->index().type().is_vector()) { + // gather + CHECK(op->type().is_vector()); + PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); + os() << "::Load("; + os() << op->tensor.As()->name; + os() << ","; + Print(op->index()); + os() << ")"; + } else if (op->is_addr_tensor()) { + auto *tensor = op->tensor.As(); + os() << tensor->name << "["; + Print(op->index()); + os() << "]"; + } else { + IrPrinter::Visit(op); + } +} + +void CodeGenC::Visit(const ir::Store *op) { + CHECK(op->is_addr_tensor()); + + auto *tensor = op->tensor.As(); + CHECK(tensor); + os() << tensor->name << "["; + Print(op->index()); + os() << "]"; + os() << " = "; + Print(op->value); +} +void CodeGenC::Visit(const ir::Alloc *op) { + os() << runtime::intrinsic::buffer_malloc; + os() << "("; + os() << "(void*)(0), "; + + auto *buffer = op->destination.As(); + os() << buffer->name; + os() << ")"; +} + +void CodeGenC::Visit(const ir::Free *op) { + os() << runtime::intrinsic::buffer_free; + os() << "("; + os() << "(void*)(0), "; + + auto *buffer = op->destination.As(); + os() << buffer->name; + os() << ")"; +} + +void CodeGenC::Visit(const ir::_Buffer_ *op) { os() << op->name; } +void CodeGenC::Visit(const ir::_Tensor_ *op) { os() << op->buffer->name; } +void CodeGenC::Visit(const ir::Let *op) { + bool is_vec = false; + CHECK(op->type().valid()); + if (op->body.defined() && op->body.As()) { + // broadcast's type is hard to print, so use c++11 auto instead. + os() << "auto"; + is_vec = true; + } else { + os() << GetTypeRepr(op->type()); + } + + os() << " "; + Print(op->symbol); + + // native C array. 
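+ // e.g. a 4-lane float Let that is not a StackVec prints as "float x[4]"
+ // (the symbol name "x" is illustrative).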
+ if (op->type().lanes() > 1 && !is_vec) {
+ os() << "[" << op->type().lanes() << "]";
+ }
+
+ if (op->body.defined()) {
+ os() << " = ";
+ Print(op->body);
+ }
+}
+
+void CodeGenC::Visit(const ir::Reduce *op) {
+ LOG(FATAL) << "Reduce IR is only for internal representation and should not be used for CodeGen.";
+}
+
+void CodeGenC::Visit(const ir::Ramp *op) {
+ os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) << ">::Ramp(";
+ Print(op->base);
+ os() << ", ";
+ Print(op->stride);
+ os() << ", ";
+ os() << op->lanes;
+ os() << ")";
+}
+
+void CodeGenC::Visit(const ir::Broadcast *op) {
+ os() << "StackVec<" << op->lanes << "," << GetTypeRepr(op->type().ElementOf()) << ">::Broadcast(";
+ Print(op->value);
+ os() << ", ";
+ os() << op->lanes << ")";
+}
+
+void CodeGenC::Visit(const ir::FracOp *op) { ir::IrPrinter::Visit(op); }
+void CodeGenC::Visit(const ir::Sum *op) { ir::IrPrinter::Visit(op); }
+void CodeGenC::Visit(const ir::Product *op) { ir::IrPrinter::Visit(op); }
+
+void CodeGenC::PrintCastExpr(const Type &type, Expr e) {
+ os() << "((" << GetTypeRepr(type) << ")";
+ os() << "(";
+ Print(e);
+ os() << "))";
+}
+void CodeGenC::PrintCastExpr(const std::string &type, Expr e) {
+ os() << "(" << type << ")";
+ os() << "(";
+ Print(e);
+ os() << ")";
+}
+
+void CodeGenC::PrintShape(const std::vector<Expr> &shape, char leftb, char rightb) {
+ os() << leftb << " ";
+
+ for (int i = 0; i < shape.size() - 1; i++) {
+ Print(shape[i]);
+ os() << ", ";
+ }
+ if (shape.size() > 1) Print(shape.back());
+
+ os() << " " << rightb;
+}
+
+void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
+ PrintFunctionDeclaration(op);
+ os() << "\n";
+
+ DoIndent();
+
+ CHECK_EQ(op->alloc_output_buffer_exprs.size(), op->dealloc_output_buffer_exprs.size())
+ << "the counts of allocation and deallocation expressions do not match";
+
+ std::vector<Expr> new_body;
+
+ std::vector<Expr> create_temp_buffers = op->PrepareCreateTempBufferExprs();
+ std::vector<Expr> alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
+ std::vector<Expr> dealloca_temp_buffers = op->PrepareDeallocTempBufferExprs();
+#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), std::begin(op->field__), std::end(op->field__));
+ APPEND_TO_NEW_BODY(argument_prepare_exprs)
+ new_body.insert(std::end(new_body), std::begin(create_temp_buffers), std::end(create_temp_buffers));
+ APPEND_TO_NEW_BODY(alloc_output_buffer_exprs)
+ new_body.insert(std::end(new_body), std::begin(alloca_temp_buffers), std::end(alloca_temp_buffers));
+ APPEND_TO_NEW_BODY(buffer_data_cast_exprs)
+ new_body.push_back(op->body);
+ new_body.insert(std::end(new_body), std::begin(dealloca_temp_buffers), std::end(dealloca_temp_buffers));
+ APPEND_TO_NEW_BODY(dealloc_output_buffer_exprs)
+
+ Expr func_body = ir::Block::Make(new_body);
+
+ optim::RemoveNestedBlock(&func_body);
+
+ Print(func_body);
+}
+void CodeGenC::PrintIncludes() {
+ os() << "#include <cinn_runtime.h>\n";
+ os() << "#include <stdio.h>\n";
+ os() << "\n";
+}
+
+void CodeGenC::PrintFileGuardOpen(const std::string &name) {
+ os() << utils::StringFormat("#ifndef _%s_CINN_H_\n", Uppercase(name).c_str());
+ os() << utils::StringFormat("#define _%s_CINN_H_\n", Uppercase(name).c_str());
+ os() << "\n";
+}
+void CodeGenC::PrintFileGuardClose(const std::string &module_name) {
+ os() << utils::StringFormat("#endif // _%s_CINN_H_\n", Uppercase(module_name).c_str());
+}
+
+void CodeGenC::PrintBufferCreation(const std::vector<ir::Buffer> &buffers) {
+ for (auto &buffer : buffers) {
+ // Ignore buffers on other devices.
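+ // (Host buffers fall through to the emission below: a Let binding a
+ // cinn_buffer_t* variable to a BufferCreate intrinsic, roughly
+ // "cinn_buffer_t* _B = <buffer_create>(target, dtype, shape[, align]);",
+ // where the callee name comes from runtime::intrinsic::buffer_create.)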
+ if (!buffer->is_on_host()) continue; + DoIndent(); + auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); + Var variable = ir::_Var_::Make(buffer->name, buffer_ptr_type); + auto expr = ir::intrinsics::BufferCreate::Make(buffer); + expr = ir::Let::Make(variable, expr); + Print(expr); + os() << ";\n"; + } +} + +void CodeGenC::PrintBufferDestroy(const std::vector &buffers) { + for (auto &buffer : buffers) { + DoIndent(); + Print(buffer.DestroyExpr()); + os() << ";\n"; + } +} + +void CodeGenC::GenerateHeaderFile(const ir::Module &module) { + PrintFileGuardOpen(module.name()); + PrintIncludes(); + + for (auto &func : module.functions()) { + PrintFunctionDeclaration(func.As()); + os() << ";\n"; + os() << "\n\n"; + } + + PrintFileGuardClose(module.name()); +} + +void CodeGenC::PrintFuncArg(const ir::Argument &arg) { + if (arg.is_buffer()) { + if (arg.is_input()) { + os() << "const struct cinn_buffer_t *"; + } else { + os() << "struct cinn_buffer_t *"; + } + } else if (arg.is_var()) { + os() << GetTypeRepr(arg.type()) << " "; + os() << arg.name(); + } else { + CINN_NOT_IMPLEMENTED + } + os() << arg.name(); +} + +void CodeGenC::PrintRuntimeType(const cinn_type_t &type) { + if (type == cinn_bool_t()) { + os() << "cinn_bool_t()"; + } else if (type == cinn_int8_t()) { + os() << "cinn_int8_t()"; + } else if (type == cinn_int16_t()) { + os() << "cinn_int16_t()"; + } else if (type == cinn_int32_t()) { + os() << "cinn_int32_t()"; + } else if (type == cinn_int64_t()) { + os() << "cinn_int64_t()"; + } else if (type == cinn_uint8_t()) { + os() << "cinn_uint8_t()"; + } else if (type == cinn_uint16_t()) { + os() << "cinn_uint16_t()"; + } else if (type == cinn_uint32_t()) { + os() << "cinn_uint32_t()"; + } else if (type == cinn_uint64_t()) { + os() << "cinn_uint64_t()"; + } else if (type == cinn_bfloat16_t()) { + os() << "cinn_bfloat16_t()"; + } else if (type == cinn_float16_t()) { + os() << "cinn_float16_t()"; + } else if (type == cinn_float32_t()) { + os() << "cinn_float32_t()"; + } else if (type == cinn_float64_t()) { + os() << "cinn_float64_t()"; + } else { + LOG(FATAL) << "Unknown type is not supported to print"; + } +} + +void CodeGenC::PrintStackVecType(Type type, int lanes) { + os() << "StackedVec<" << GetTypeRepr(type) << "," << lanes << ">"; +} + +void CodeGenC::Visit(const ir::PrimitiveNode *op) { CINN_NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::_BufferRange_ *op) { CINN_NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::ScheduleBlock *op) { CINN_NOT_IMPLEMENTED } +void CodeGenC::Visit(const ir::ScheduleBlockRealize *op) { CINN_NOT_IMPLEMENTED } + +void CodeGenC::Visit(const ir::IntrinsicOp *op) { + switch (op->getKind()) { +#define __(x) \ + case ir::IntrinsicKind::k##x: \ + Visit(llvm::dyn_cast(op)); \ + break; + + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } +} + +void CodeGenC::Visit(const ir::intrinsics::BufferGetDataHandle *op) { + os() << op->buffer.as_buffer()->name; + os() << "->"; + os() << "memory"; +} + +void CodeGenC::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) { + os() << op->buffer.as_buffer()->name; + os() << "->"; + os() << "memory"; +} + +void CodeGenC::Visit(const ir::intrinsics::PodValueToX *op) { + auto to_type = op->GetOutputType(0); + if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_float; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_double; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_float16; + } else if (to_type == 
type_of()) { + os() << runtime::intrinsic::pod_value_to_bool; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_int8; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_int16; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_int32; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_int64; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_uint8; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_uint16; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_uint32; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_uint64; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_void_p; + } else if (to_type == type_of()) { + os() << runtime::intrinsic::pod_value_to_buffer_p; + } else { + LOG(FATAL) << "Not supported type: " << to_type; + } + + os() << "("; + Print(op->pod_value_ptr); + os() << ")"; +} + +void CodeGenC::Visit(const ir::intrinsics::BufferCreate *op) { + const ir::_Buffer_ *buffer_arg = op->buffer.as_buffer(); + CHECK(buffer_arg); + + os() << runtime::intrinsic::buffer_create; + os() << "("; + PrintCastExpr("cinn_device_kind_t", Expr(buffer_arg->target.runtime_arch())); + os() << "/*target*/, "; + PrintRuntimeType(runtime::ToRuntimeType(buffer_arg->dtype.ElementOf())); + os() << ", "; + PrintShape(buffer_arg->shape); + if (buffer_arg->data_alignment > 0) { + os() << ", " << buffer_arg->data_alignment << "/*align*/"; + } + os() << ")"; +} + +void CodeGenC::Visit(const ir::intrinsics::GetAddr *op) { + if (op->data.as_buffer()) { + os() << "&" << op->data.as_buffer()->name; + } else if (op->data.as_var()) { + os() << "&" << op->data.as_var()->name; + } else { + os() << "&("; + Print(op->data); + os() << ")"; + } +} + +void CodeGenC::Visit(const ir::intrinsics::ArgsConstruct *op) { + os() << runtime::intrinsic::args_construct_repr << "("; + os() << op->var->name << ", "; + os() << op->args.size() << ", "; + for (int i = 0; i < op->args.size() - 1; i++) { + Print(op->args[i]); + os() << ", "; + } + if (!op->args.empty()) { + Print(op->args.back()); + } + os() << ")"; +} + +void CodeGenC::Visit(const ir::intrinsics::BuiltinIntrin *op) { + os() << op->name << "("; + if (!op->args.empty()) { + for (int i = 0; i < op->args.size() - 1; i++) { + Print(op->args[i]); + os() << ", "; + } + Print(op->args.back()); + } + os() << ")"; +} + +std::string ReadWholeFile(const std::string &path) { + CHECK(!path.empty()); + std::ifstream file(path); + CHECK(file.is_open()) << "Failed to open file: " << path; + std::stringstream ss; + ss << file.rdbuf(); + return ss.str(); +} + +void CodeGenC::PrintBuiltinCodes() { + CHECK(!FLAGS_cinn_x86_builtin_code_root.empty()) << "The flag cinn_x86_builtin_code_root should be set first"; + + const std::string x86_code_file = "_x86_builtin_source.cc"; + + auto source = ReadWholeFile(FLAGS_cinn_x86_builtin_code_root + "/" + x86_code_file); + + os() << source << "\n"; +} + +namespace detail { + +Expr StridedRampBase(Expr e, int stride) { + auto *ramp_n = e.As(); + if (ramp_n) { + auto *iv = ramp_n->stride.As(); + if (iv && iv->value == stride) return ramp_n->base; + } + return Expr(); +} + +} // namespace detail + +} // namespace backends + +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_c.h b/paddle/cinn/backends/codegen_c.h new file mode 100755 index 0000000000000..42458d549bed3 --- /dev/null +++ 
b/paddle/cinn/backends/codegen_c.h @@ -0,0 +1,127 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "cinn/common/common.h" +#include "cinn/ir/intrinsic_ops.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/ir/module.h" +#include "cinn/lang/packed_func.h" +#include "cinn/runtime/cinn_runtime.h" + +namespace cinn { + +namespace ir { +class Module; +} // namespace ir + +namespace backends { + +//! keyword of __restrict__. +extern const char* kCKeywordRestrict; + +class CodeGenC : public ir::IrPrinter { + public: + enum class OutputKind { + CHeader, //! output the C header file. + CImpl, //! output the C implementation file. + }; + + explicit CodeGenC(Target target); + + void Compile(const ir::Module& module, const Outputs& outputs); + + virtual std::string Compile(const ir::Module& module, OutputKind output_kind); + + //! Disable inline the builtin codes(too large) for simpler string comparison. + void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; } + + protected: + std::string Compile(const ir::LoweredFunc& function); + std::string Compile(const ir::Buffer& buffer); + + void GenerateHeaderFile(const ir::Module& module); + + std::string GetTypeName(Type type); + + std::string GetTypeRepr(Type type); + //! type cast, print like "int(x)" + // @{ + void PrintCastExpr(const Type& type, Expr e); + void PrintCastExpr(const std::string& type, Expr e); + // @} + + void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) { + os() << "void " << op->name << "("; + os() << "void* _args, int32_t num_args"; + os() << ")"; + } + + void PrintShape(const std::vector& shape, char leftb = '{', char rightb = '}'); + + virtual void PrintIncludes(); + void PrintBuiltinCodes(); + void PrintFileGuardOpen(const std::string& module_name); + void PrintFileGuardClose(const std::string& module_name); + + //! Create the buffers in global scope(just creation without allocating them). + void PrintBufferCreation(const std::vector& buffers); + void PrintBufferDestroy(const std::vector& buffers); + void PrintRuntimeType(const cinn_type_t& type); + + //! Print different kinds of Calls. 
+ // @{ + void PrintCallArgs(const ir::Call* call); + void PrintCall_buffer_malloc(const ir::Call* op); + void PrintCall_cinn_pod_value_to_(const ir::Call* op); + void PrintCall_get_address(const ir::Call* op); + void PrintCall_pod_values_to_array(const ir::Call* op); + // @} + +#define __DEFINE_VISIT(op__) void Visit(const ir::op__* op) override; + NODETY_FORALL(__DEFINE_VISIT) +#undef __DEFINE_VISIT + +#define __DEFINE_VISIT(op__) void Visit(const ir::intrinsics::op__* op) override; + INTRINSIC_KIND_FOR_EACH(__DEFINE_VISIT) +#undef __DEFINE_VISIT + + void PrintFuncArg(const ir::Argument& arg); + + void PrintStackVecType(Type type, int lanes); + + friend class ExternFunctionEmitter; + + protected: + Target target_; + std::stringstream ss_; + bool inline_builtin_codes_{true}; +}; + +namespace detail { + +Expr StridedRampBase(Expr e, int stride); + +} // namespace detail + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_c_test.cc b/paddle/cinn/backends/codegen_c_test.cc new file mode 100755 index 0000000000000..3a95774c2f53f --- /dev/null +++ b/paddle/cinn/backends/codegen_c_test.cc @@ -0,0 +1,436 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
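+
+// The tests below lower small Compute() definitions with lang::Lower and
+// compare the generated C against golden strings (the R"ROC(...)ROC" blocks).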
+ +#include "cinn/backends/codegen_c.h" + +#include + +#include +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/module.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/runtime/cpu/use_extern_funcs.h" + +namespace cinn { +namespace backends { + +using ir::Module; +using lang::Compute; +using lang::Lower; +using lang::Placeholder; +using utils::StringFormat; +using utils::Trim; + +std::tuple CreateTensor1() { + Expr M(100); + Expr N(20); + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + lang::Buffer C_buf(Float(32)); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); + C->Bind(C_buf); + return std::make_tuple(A, B, C, C_buf); +} + +TEST(CodeGenC, module) { + ir::Tensor A, B, C; + lang::Buffer C_buf(Float(32)); + std::tie(A, B, C, C_buf) = CreateTensor1(); + + LOG(INFO) << "C.body: " << C->get_compute_op()->body.front(); + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + Module::Builder builder("module1", target); + + auto stages = CreateStages({A, B, C}); + auto func = Lower("add1", stages, {A, B, C}); + + builder.AddFunction(func); + + { + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "codegen C:" << std::endl << out << std::endl; + + std::string target_str = R"ROC( +#include +#include + +void add1(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i = 0; i < 100; i += 1) { + for (int32_t j = 0; j < 20; j += 1) { + C[((20 * i) + j)] = (A[((20 * i) + j)] + B[((20 * i) + j)]); + }; + }; + cinn_buffer_free((void*)(0), _C); +} +)ROC"; + EXPECT_EQ(utils::Trim(target_str), utils::Trim(out)); + } + + { + CodeGenC compiler(target); + auto out = compiler.Compile(builder.Build(), CodeGenC::OutputKind::CHeader); + std::cout << "header:\n" << out << std::endl; + auto target_str = R"ROC( +#ifndef _MODULE1_CINN_H_ +#define _MODULE1_CINN_H_ + +#include +#include + +void add1(void* _args, int32_t num_args); + + +#endif // _MODULE1_CINN_H_ +)ROC"; + + EXPECT_EQ(utils::Trim(out), utils::Trim(target_str)); + } + + { + CodeGenC compiler(target); + compiler.SetInlineBuiltinCodes(false); + Outputs outputs; + outputs = outputs.c_header("./generated_module1.h").c_source("./_generated_module1.cc"); + compiler.Compile(builder.Build(), outputs); + } +} + +TEST(CodeGenC, matmul) { + using namespace ir; // NOLINT + Context::Global().ResetNameId(); + + Placeholder A("A", {Expr(100), Expr(20)}); + Placeholder B("B", {Expr(20), Expr(50)}); + + Target target{}; + + Module::Builder builder("module1", target); + + // C = A * B + Var k(20, "k0"); + + Tensor C = Compute( + {Expr(100), Expr(50)}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + auto stages = CreateStages({A, B, C}); + + // Code gen + auto func = Lower("matmul", stages, {A, B, C}); + 
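+ // Lower() turns the ReduceSum into an explicit k0 loop plus a separate
+ // C__reduce_init zero-fill, which is exactly what the golden string below checks.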
builder.AddFunction(func); + builder.AddBuffer(C->buffer); + + { // main + std::vector returns({lang::ReturnType{Float(32), C->shape, C->name}}); + + auto tensors = lang::CallLowered("matmul", {A, B}, returns); + + auto C = tensors[0]; + C->WithBuffer(); + + LOG(INFO) << "C.body: " << C->body(); + + auto stages = CreateStages({C}); + + auto f = Lower("main", stages, {A, B, C}, {}); + std::cout << "f\n" << Expr(f) << std::endl; + builder.AddFunction(f); + } + + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "codegen C:" << std::endl << out << std::endl; + + auto tgt = R"ROC( +#include +#include + +void matmul(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C__reduce_init = ((float*)(_C->memory)); + for (int32_t i = 0; i < 100; i += 1) { + for (int32_t j = 0; j < 50; j += 1) { + C__reduce_init[((50 * i) + j)] = 0.00000000f; + for (int32_t k0 = 0; k0 < 20; k0 += 1) { + C[((50 * i) + j)] = (C[((50 * i) + j)] + (A[((20 * i) + k0)] * B[((50 * k0) + j)])); + }; + }; + }; + cinn_buffer_free((void*)(0), _C); +} + +void main(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + cinn_pod_value_t _pod_val_; + buffer_p_to_cinn_pod_value(_A, &_pod_val_); + cinn_pod_value_t _pod_val__0; + buffer_p_to_cinn_pod_value(_B, &_pod_val__0); + cinn_pod_value_t _pod_val__1; + buffer_p_to_cinn_pod_value(_C, &_pod_val__1); + cinn_pod_value_t _pod_arr[3]; + cinn_args_construct(_pod_arr, 3, &_pod_val_, &_pod_val__0, &_pod_val__1); + matmul(_pod_arr, 3); + cinn_buffer_free((void*)(0), _C); +} +)ROC"; + + ASSERT_EQ(Trim(tgt), Trim(out)); +} + +// This matches output of competitor. 
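+// The schedule below tiles i and j by bn = 32, splits k0 by 4, and reorders the
+// axes to (i_outer, j_outer, i_inner, j_inner, k0_outer, k0_inner); the golden
+// string is the direct lowering of that schedule, with fma() in the inner loop.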
+TEST(CodeGenC, matmul_tile) { + using namespace ir; // NOLINT + Expr M(100); + Expr K(200); + Expr N(500); + Expr bn(32); + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + // C = A * B + Var k(K.as_int32(), "k0"); + + Tensor C_init = Compute( + {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init"); + + Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + auto stages = CreateStages({C, C_init}); + stages[C]->ShareBufferWith(stages[C_init]); + + { + auto _i_outer_i_inner_j_outer_j_inner_ = stages[C_init]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + stages[C_init]->Reorder({i_outer, j_outer, i_inner, j_inner}); + } + + { + auto _i_outer_i_inner_j_outer_j_inner_ = stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); // NOLINT + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + auto _k_outer_k_inner_ = stages[C]->Split(poly::Iterator("k0"), 4); // NOLINT + auto &k_outer = std::get<0>(_k_outer_k_inner_); + auto &k_inner = std::get<1>(_k_outer_k_inner_); + stages[C]->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner}); + } + + stages[C_init]->ComputeAtSchedule(stages[C], 3, poly::Stage::kComputeAtBefore); + + // Code gen + auto func = Lower("matmul", stages, {A, B, C}); + + Target target = common::DefaultHostTarget(); + + Module::Builder builder("module1", target); + builder.AddFunction(func); + builder.AddBuffer(C_init->buffer); + + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "codegen C:" << std::endl << out << std::endl; + + auto target_out = R"ROC( +#include +#include + +void matmul(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C__reduce_init = ((float*)(_C->memory)); + float* C_init = ((float*)(_C->memory)); + for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) { + for (int32_t i_inner = 0; i_inner < cinn_min(32, (100 + (-32 * i_outer))); i_inner += 1) { + for (int32_t j_inner = 0; j_inner < cinn_min(32, (500 + (-32 * j_outer))); j_inner += 1) { + C__reduce_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0.00000000f; + C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0.00000000f; + for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { + for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = fma(A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))], B[((32 * j_outer) + 
((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))], C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))]); + }; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _C); +} +)ROC"; + + ASSERT_EQ(Trim(target_out), Trim(out)); +} + +TEST(CodeGenC, matmul_packed) { + Expr M(100); + Expr K(200); + Expr N(500); + Expr bn(32); + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + // TODO(Superjomn) Make sure the domain works. + Var k(K.as_int32(), "k0"); + auto packedB = Compute( + {N / bn, K, bn}, [&](Expr x, Expr y, Expr z) { return B(y, x * bn + z); }, "PackedB"); + auto C = Compute( + {M, N}, [&](Expr i, Expr j) { return ReduceSum(A(i, k) * packedB(j / bn, k, j % bn), {k}); }, "C"); + + auto stages = CreateStages({packedB, C}); + + { + auto _i_outer_i_inner_j_outer_j_inner_ = stages[C]->Tile(0, 1, bn.as_int32(), bn.as_int32()); + auto &i_outer = std::get<0>(_i_outer_i_inner_j_outer_j_inner_); + auto &i_inner = std::get<1>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_outer = std::get<2>(_i_outer_i_inner_j_outer_j_inner_); + auto &j_inner = std::get<3>(_i_outer_i_inner_j_outer_j_inner_); + auto _k_outer_k_inner_ = stages[C]->Split(poly::Iterator("k0"), 4); + auto &k_outer = std::get<0>(_k_outer_k_inner_); + auto &k_inner = std::get<1>(_k_outer_k_inner_); + stages[C]->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner}); + } + + // Code gen + auto func = Lower("matmul_with_packing", stages, {A, B, packedB, C}); + + Target target = common::DefaultHostTarget(); + + Module::Builder builder("module1", target); + builder.AddFunction(func); + builder.AddBuffer(C->buffer); + builder.AddBuffer(packedB->buffer); + + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "codegen C:" << std::endl << out << std::endl; + + auto target_out = R"ROC( +#include +#include + +void matmul_with_packing(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _PackedB = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[3])); + cinn_buffer_malloc((void*)(0), _PackedB); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C__reduce_init = ((float*)(_C->memory)); + float* PackedB = ((float*)(_PackedB->memory)); + for (int32_t i = 0; i < 15; i += 1) { + for (int32_t j = 0; j < 200; j += 1) { + for (int32_t k = 0; k < 32; k += 1) { + PackedB[((6400 * i) + ((32 * j) + k))] = B[((32 * i) + ((500 * j) + k))]; + }; + }; + }; + for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) { + for (int32_t i_inner = 0; i_inner < cinn_min(32, (100 + (-32 * i_outer))); i_inner += 1) { + for (int32_t j_inner = 0; j_inner < cinn_min(32, (500 + (-32 * j_outer))); j_inner += 1) { + C__reduce_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0; + for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { + for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { + C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = fma(A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))], 
PackedB[((6400 * (j_inner / 32)) + ((j_inner & 31) + ((6400 * j_outer) + ((32 * k0_inner) + (128 * k0_outer)))))], C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))]); + }; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _PackedB); + cinn_buffer_free((void*)(0), _C); +} +)ROC"; + // ToDo @haoze @wangyue Check Codegen + // ASSERT_EQ(utils::Trim(target_out), utils::Trim(out)); +} + +TEST(CodeGenC, call_extern) { + Expr M(100); + + Placeholder x("x", {M}); + + ir::Tensor y = Compute( + {M}, [=](Var i) -> Expr { return lang::CallExtern("tanh", {x(i)}); }, "y"); + + auto stages = CreateStages({y}); + + auto yexpr = Lower("yy", stages, {y}); + + Module::Builder builder("module0", common::DefaultHostTarget()); + builder.AddFunction(yexpr); + + CodeGenC codegen(common::DefaultHostTarget()); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "codegen C:" << std::endl << out << std::endl; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_c_x86.cc b/paddle/cinn/backends/codegen_c_x86.cc new file mode 100644 index 0000000000000..737566dc2c651 --- /dev/null +++ b/paddle/cinn/backends/codegen_c_x86.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/codegen_c_x86.h" + +namespace cinn { +namespace backends { + +void CodeGenCX86::Visit(const ir::Add *op) { VisitBinaryOp(op, op->a(), op->b(), "add"); } +void CodeGenCX86::Visit(const ir::Sub *op) { VisitBinaryOp(op, op->a(), op->b(), "sub"); } +void CodeGenCX86::Visit(const ir::Mul *op) { VisitBinaryOp(op, op->a(), op->b(), "mul"); } +void CodeGenCX86::Visit(const ir::Div *op) { VisitBinaryOp(op, op->a(), op->b(), "div"); } + +void CodeGenCX86::Visit(const ir::Load *op) { + Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); + if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address. 
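+ // e.g. a float32 load whose index is Ramp(base, 1, 16) is 16 x 32 = 512 bits
+ // wide and lowers to "cinn_avx512_load(A + base)" via PrintAbsAddr
+ // (the tensor name "A" is illustrative).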
+ CHECK(op->type().is_vector()); + + int bits = op->type().bits() * op->type().lanes(); + if (SupportsAVX512() && bits == 512) { + os() << "cinn_avx512_load("; + PrintAbsAddr(op); + os() << ")"; + } else if (SupportsAVX256() && bits == 256) { + os() << "cinn_avx256_load("; + PrintAbsAddr(op); + os() << ")"; + } else { + CodeGenC::Visit(op); + } + } else { + CodeGenC::Visit(op); + } +} + +void CodeGenCX86::Visit(const ir::Broadcast *op) { + CHECK_GT(op->type().lanes(), 1); + int bits = op->type().bits() * op->type().lanes(); + + if (SupportsAVX512() && bits == 512) { + os() << "cinn_avx512_set1("; + PrintCastExpr(op->value.type().ElementOf(), op->value); + os() << ")"; + } else if (SupportsAVX256() && bits == 256) { + os() << "cinn_avx256_set1("; + PrintCastExpr(op->value.type().ElementOf(), op->value); + os() << ")"; + } else { + CodeGenC::Visit(op); + } +} + +void CodeGenCX86::Visit(const ir::Store *op) { + if (op->type().lanes() == 1) { + CodeGenC::Visit(op); + return; + } + + int bits = op->type().bits() * op->type().lanes(); + if (SupportsAVX512() && bits == 512) { + os() << "cinn_avx512_store("; + PrintAbsAddr(op); + os() << ", "; + Print(op->value); + os() << ")"; + } else if (SupportsAVX256() && bits == 256) { + os() << "cinn_avx256_store("; + PrintAbsAddr(op); + os() << ", "; + Print(op->value); + os() << ")"; + } else { + CodeGenC::Visit(op); + } +} + +void CodeGenCX86::PrintVecInputArgument(const Expr *op) { + int bits = op->type().bits() * op->type().lanes(); + auto *broadcast_n = op->As(); + + if (op->type().lanes() == 1 || broadcast_n) { + Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value; + + if (SupportsAVX512()) { + os() << "cinn_avx512_set1("; + Print(value); + os() << ")"; + } else if (SupportsAVX256()) { + os() << "cinn_avx256_set1("; + Print(value); + os() << ")"; + } else { + CINN_NOT_IMPLEMENTED + } + } else { + Print(*op); + } +} + +void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) { + if (op->type().lanes() == 1) { + CodeGenC::Visit(op); + return; + } + int bits = op->type().bits() * op->type().lanes(); + if (SupportsAVX512() && bits == 512) { + os() << "cinn_avx512_" << op->name << "("; + if (!op->args.empty()) { + for (int i = 0; i < op->args.size() - 1; i++) { + PrintVecInputArgument(&op->args[i]); + os() << ", "; + } + Print(op->args.back()); + } + os() << ")"; + } else if (SupportsAVX256() && bits == 256) { + os() << "cinn_avx256_" << op->name << "("; + if (!op->args.empty()) { + for (int i = 0; i < op->args.size() - 1; i++) { + PrintVecInputArgument(&op->args[i]); + os() << ", "; + } + PrintVecInputArgument(&op->args.back()); + } + os() << ")"; + } else if (bits == 128) { + os() << "cinn_avx128_" << op->name << "("; + if (!op->args.empty()) { + for (int i = 0; i < op->args.size() - 1; i++) { + PrintVecInputArgument(&op->args[i]); + os() << ", "; + } + PrintVecInputArgument(&op->args.back()); + } + os() << ")"; + } else { + CodeGenC::Visit(op); + } +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_c_x86.h b/paddle/cinn/backends/codegen_c_x86.h new file mode 100644 index 0000000000000..29555df3c5e9a --- /dev/null +++ b/paddle/cinn/backends/codegen_c_x86.h @@ -0,0 +1,131 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+#include "cinn/backends/codegen_c.h"
+#include "cinn/ir/intrinsic_ops.h"
+
+namespace cinn {
+namespace backends {
+
+/**
+ * C code generation with X86 instruction or math library support.
+ */
+class CodeGenCX86 : public CodeGenC {
+ public:
+ //! The X86 CPU supports the following features. We use SSE or AVX instructions to accelerate the basic operations
+ //! when a for-loop is vectorized.
+ enum class Feature : int {
+ None = 0,
+ SSE = 1, //! support the SSE instruction set.
+ AVX256 = 1 << 1, //! support the AVX256 instruction set.
+ AVX512 = 1 << 2, //! support the AVX512 instruction set.
+ BLAS = 1 << 3, //! support the BLAS library.
+ };
+
+ Feature feature{Feature::None};
+
+ /**
+ * constructor.
+ * @param target The device.
+ * @param feature The features it supports.
+ */
+ CodeGenCX86(Target target, Feature feature) : CodeGenC(target), feature(feature) {}
+
+ protected:
+ void Visit(const ir::Add *op) override;
+ void Visit(const ir::Sub *op) override;
+ void Visit(const ir::Mul *op) override;
+ void Visit(const ir::Div *op) override;
+ void Visit(const ir::Mod *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::EQ *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::NE *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::LT *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::LE *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::GT *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::GE *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::And *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::Or *op) override { CodeGenC::Visit(op); }
+ void Visit(const ir::Load *op) override;
+ void Visit(const ir::Store *op) override;
+ void Visit(const ir::Broadcast *op) override;
+ void Visit(const ir::intrinsics::BuiltinIntrin *op);
+
+ //! Check the features.
+ // @{
+ bool SupportsSSE() { return static_cast<int>(feature) & static_cast<int>(Feature::SSE); }
+ bool SupportsAVX256() { return static_cast<int>(feature) & static_cast<int>(Feature::AVX256); }
+ bool SupportsAVX512() { return static_cast<int>(feature) & static_cast<int>(Feature::AVX512); }
+ bool SupportsBLAS() { return static_cast<int>(feature) & static_cast<int>(Feature::BLAS); }
+ // @}
+
+ //! Print (and prepare) an argument in vectorized form, for example:
+ // 3. -> set1(3.)
+ // a[i:j] -> load_ps(a+i)
+ void PrintVecInputArgument(const Expr *op);
+ //! The output argument, such as the destination for Load.
+ void PrintVecOutputArgument(const Expr *op);
+
+ template <typename Op>
+ void PrintAbsAddr(const Op *op) {
+ os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
+
+ auto index = op->index();
+ auto *ramp_n = index.template As<ir::Ramp>();
+ if (ramp_n) {
+ CHECK(!ramp_n->base.template As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
+ Print(ramp_n->base);
+ } else {
+ Print(op->index());
+ }
+ }
+
+ template <typename Op>
+ void VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string &op_repr);
+};
+
+template <typename Op>
+void CodeGenCX86::VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string &op_repr) {
+ CHECK_EQ(a.type(), b.type()) << " a is : " << a << ", and b is : " << b << ". 
op_repr is : " << op_repr; + + // scalar. + if (a.type().lanes() == 1) { + CodeGenC::Visit(op); + return; + } + + // TODO(Superjomn) Consider support BLAS. + int bits = a.type().bits() * a.type().lanes(); + if (SupportsAVX512() && bits == 512) { + os() << "cinn_avx512_" << op_repr << "("; + PrintVecInputArgument(&a); + os() << ", "; + PrintVecInputArgument(&b); + os() << ")"; + } else if (SupportsAVX256() && bits == 256) { + os() << "cinn_avx256_" << op_repr << "("; + PrintVecInputArgument(&a); + os() << ", "; + PrintVecInputArgument(&b); + os() << ")"; + } else { + CodeGenC::Visit(op); + } +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_c_x86_test.cc b/paddle/cinn/backends/codegen_c_x86_test.cc new file mode 100644 index 0000000000000..b4cb6bf376a51 --- /dev/null +++ b/paddle/cinn/backends/codegen_c_x86_test.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/codegen_c_x86.h" + +#include + +#include "cinn/cinn.h" +#include "cinn/ir/module.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/transform_polyfor_to_for.h" +#include "cinn/optim/vectorize_loops.h" + +namespace cinn { +namespace backends { + +TEST(CodeGenCX86, basic) { + // create two forloops, check only one forloop is marked Vectorize. + Context::info_rgt().Clear(); + + using namespace ir; // NOLINT + + const int M = 100; + const int K = 200; + const int N = 500; + const int bn = 32; + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + // C = A * B + Tensor C = Compute( + {Expr(M), Expr(N)}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C"); + + Tensor D = Compute( + {Expr(M), Expr(N)}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "D"); + + auto stages = CreateStages({C, D}); + // vectorize C, not D + stages[C]->Vectorize(1, 16); + stages[C]->Unroll(1); + + auto func = Lower("matmul", stages, {A, B, C, D}); + + std::cout << "before optim\n" << func->body << std::endl; + + ir::Module::Builder builder("module1", target); + builder.AddFunction(func); + + CodeGenCX86 codegen(target, CodeGenCX86::Feature::AVX512); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "out:\n" << out; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_dev.cc b/paddle/cinn/backends/codegen_cuda_dev.cc new file mode 100644 index 0000000000000..21fc8961faeea --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_dev.cc @@ -0,0 +1,391 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/codegen_cuda_dev.h" + +#include +#include + +#include +#include +#include + +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_verify.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/remove_nested_block.h" + +namespace cinn { +namespace backends { + +const std::string CodeGenCUDA_Dev::source_header_ = + R"(#include + +#define CINN_WITH_CUDA +#include "bfloat16.h" +#include "float16.h" +using cinn::common::bfloat16; +using cinn::common::float16; +using cinn::common::half4; +using cinn::common::half8; +using cinn::common::float8; + +#include "cinn_cuda_runtime_source.cuh" +)"; + +const std::string &CodeGenCUDA_Dev::GetSourceHeader() { return source_header_; } + +CodeGenCUDA_Dev::CodeGenCUDA_Dev(Target target) : CodeGenC(target) {} + +std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, bool for_nvrtc) { + for_nvrtc_ = for_nvrtc; + auto source = Compile(module, OutputKind::CImpl); + + return source; +} + +void CodeGenCUDA_Dev::Compile(const ir::Module &module, const Outputs &outputs) { + ir::IrVerify(Expr(module)); + + CodeGenC::inline_builtin_codes_ = false; + if (!outputs.c_header_name.empty()) { + auto source = Compile(module, OutputKind::CHeader); + std::ofstream file(outputs.c_header_name); + CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name; + file << source; + file.close(); + LOG(WARNING) << "Output C header to file " << outputs.c_header_name; + } + + if (!outputs.cuda_source_name.empty()) { + auto source = Compile(module, OutputKind::CImpl); + std::ofstream file(outputs.cuda_source_name); + CHECK(file.is_open()) << "failed to open file " << outputs.cuda_source_name; + file << source; + file.close(); + LOG(WARNING) << "Output C source to file " << outputs.cuda_source_name; + } +} + +std::string CodeGenCUDA_Dev::Compile(const ir::LoweredFunc &func) { + Print(Expr(func)); + return ss_.str(); +} + +std::vector CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFunc_ *op, + const std::vector &temp_buffers) { + std::set temp_buffer_set(temp_buffers.begin(), temp_buffers.end()); + // prepare temp buffer alias + std::vector buffer_alias; + auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) { + return x->as_tensor() && x->as_tensor()->buffer.defined() && temp_buffer_set.count(x->as_tensor()->buffer); + }); + + // unique tensors + std::set unique_tensors; + for (auto &e : tensors) { + unique_tensors.insert(e.as_tensor_ref()); + } + + for (auto &t : unique_tensors) { + auto data_type = t->type(); + auto data_ptr_type = data_type; + data_ptr_type.set_cpp_handle(); + + Var t_var(t->name, data_ptr_type); + Var buf_var(t->buffer->name, data_ptr_type); + buffer_alias.push_back(ir::Let::Make(t_var, buf_var)); + } + + return buffer_alias; +} + +void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) { + // clear names valid within scope when enter a new function + vectorized_tensor_names_.clear(); + os() << "__global__\n"; + + 
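+ // Each lowered function becomes a single __global__ kernel; when
+ // cuda_axis_info is valid, the declaration printed below also carries
+ // __launch_bounds__(product of the three block dims), e.g.
+ // "__global__ void __launch_bounds__(256) fn(const float* __restrict__ A, ...)"
+ // (the count 256 and the signature are illustrative).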
PrintFunctionDeclaration(op);
+ os() << "\n";
+
+ DoIndent();
+
+ std::vector<Expr> new_body;
+
+ auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
+ auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs);
+ auto alis_var_exprs = op->CudaAliasVarExprs();
+
+#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), std::begin(field__), std::end(field__));
+ APPEND_TO_NEW_BODY(alloca_temp_buffers)
+ APPEND_TO_NEW_BODY(temp_buffer_alias)
+ APPEND_TO_NEW_BODY(alis_var_exprs)
+
+ new_body.push_back(op->body);
+
+ Expr func_body = ir::Block::Make(new_body);
+
+ optim::RemoveNestedBlock(&func_body);
+ // Make sure that the function's body is wrapped by a block
+ if (!func_body.As<ir::Block>()) {
+ func_body = ir::Block::Make({func_body});
+ }
+ Print(func_body);
+}
+
+void CodeGenCUDA_Dev::Visit(const ir::_Var_ *op) {
+ if (utils::Startswith(op->name, "threadIdx") || utils::Startswith(op->name, "blockIdx")) {
+ os() << "(int)" + op->name;
+ } else {
+ os() << op->name;
+ }
+}
+
+void CodeGenCUDA_Dev::Visit(const ir::Alloc *op) {
+ CHECK(op->destination.as_buffer());
+ PrintTempBufferCreation(op->destination.as_buffer_ref());
+}
+
+void CodeGenCUDA_Dev::Visit(const ir::Min *op) {
+ os() << "min(";
+ Print(op->a());
+ os() << ", ";
+ Print(op->b());
+ os() << ")";
+}
+
+void CodeGenCUDA_Dev::Visit(const ir::Max *op) {
+ os() << "max(";
+ Print(op->a());
+ os() << ", ";
+ Print(op->b());
+ os() << ")";
+}
+
+void CodeGenCUDA_Dev::PrintFunctionDeclaration(const ir::_LoweredFunc_ *op) {
+ os() << "void ";
+ if (op->cuda_axis_info.valid()) {
+ int thread_num = 1;
+ for (int i = 0; i < 3; i++) {
+ thread_num *= op->cuda_axis_info.block_dim(i);
+ }
+ os() << "__launch_bounds__(" << thread_num << ") ";
+ }
+
+ os() << op->name << "(";
+ for (int i = 0; i < op->args.size() - 1; i++) {
+ auto &arg = op->args[i];
+ PrintFuncArg(arg);
+ os() << ", ";
+ }
+ if (!op->args.empty()) {
+ PrintFuncArg(op->args.back());
+ }
+ os() << ")";
+}
+
+void CodeGenCUDA_Dev::PrintFuncArg(const ir::Argument &arg) {
+ if (arg.is_buffer()) {
+ // In a CUDA kernel only primitive types are supported, so a buffer argument is passed as a plain T* pointer.
+ if (arg.is_input()) os() << "const ";
+ os() << GetTypeRepr(arg.buffer_arg()->dtype);
+ os() << "* ";
+ os() << kCKeywordRestrict << " ";
+ os() << ir::BufferGetTensorName(arg.buffer_arg().As<ir::_Buffer_>());
+ } else if (arg.is_var()) {
+ if (arg.var_arg()->type().is_cpp_handle()) {
+ os() << kCKeywordRestrict;
+ }
+ os() << GetTypeRepr(arg.type()) << " ";
+ os() << arg.name();
+ } else {
+ CINN_NOT_IMPLEMENTED
+ }
+}
+
+void CodeGenCUDA_Dev::PrintBuiltinCodes() {}
+
+std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, CodeGenC::OutputKind output_kind) {
+ if (output_kind == OutputKind::CHeader) {
+ GenerateHeaderFile(module);
+ } else if (output_kind == OutputKind::CImpl) {
+ PrintIncludes();
+
+ if (for_nvrtc_) {
+ os() << "\nextern \"C\" {\n\n";
+ }
+
+ PrintBuiltinCodes();
+
+ for (auto &func : module.functions()) {
+ Compile(func);
+ }
+ } else {
+ LOG(FATAL) << "Not supported OutputKind";
+ }
+
+ if (for_nvrtc_) {
+ os() << "\n\n}";
+ }
+
+ return ss_.str();
+}
+
+void CodeGenCUDA_Dev::PrintIncludes() { os() << GetSourceHeader(); }
+
+void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) {
+ CHECK_NE(buffer->type(), Void());
+ auto print_gpu_memory = [&](const std::string &mark) {
+ os() << mark << GetTypeRepr(buffer->dtype) << " " << buffer->name << " ";
+
+ os() << "[ ";
+ Expr buffer_size(1);
+ for (int i = 0; i < buffer->shape.size(); i++) {
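+ // Fold all extents into one flat element count; optim::Simplify below
+ // collapses the product into a single constant used as the array extent.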
buffer_size = buffer_size * buffer->shape[i]; + } + optim::Simplify(&buffer_size); + Print(buffer_size); + os() << " ]"; + }; + switch (buffer->memory_type) { + case ir::MemoryType::GPUShared: + print_gpu_memory("__shared__ "); + break; + + case ir::MemoryType::GPULocal: + print_gpu_memory(""); + break; + + default: + LOG(FATAL) << "CUDA device codegen not support memory " << buffer->name << ", type " << buffer->memory_type; + } +} + +void CodeGenCUDA_Dev::Visit(const ir::Call *op) { + os() << op->name + "("; + + if (!op->read_args.empty()) { + for (int i = 0; i < op->read_args.size() - 1; i++) { + auto &arg = op->read_args[i]; + if (arg.as_tensor()) { + os() << arg.as_tensor()->name; + os() << ", "; + } else { + Print(arg); + os() << ", "; + } + } + if (op->read_args.back().as_tensor()) { + os() << op->read_args.back().as_tensor()->name; + } else { + Print(op->read_args.back()); + } + } + + if (!op->write_args.empty()) { + os() << ", "; + for (int i = 0; i < op->write_args.size() - 1; i++) { + auto &arg = op->write_args[i]; + if (arg.as_tensor()) { + os() << arg.as_tensor()->name; + os() << ", "; + } else { + Print(arg); + os() << ", "; + } + } + if (op->write_args.back().as_tensor()) { + os() << op->write_args.back().as_tensor()->name; + } else { + Print(op->write_args.back()); + } + } + + os() << ")"; +} + +void CodeGenCUDA_Dev::Visit(const ir::Let *op) { + CHECK(op->type().valid()); + + // identify vectorized tensors by checking their dtypes are customized_type + // with customized_type::kcuda_builtin_vector_t prefix, and save their names + if (op->type().is_customized() && + utils::Startswith(op->type().customized_type(), common::customized_type::kcuda_builtin_vector_t)) { + os() << GetTypeRepr(op->type()); + if (op->type().is_cpp_handle()) { + os() << " " << kCKeywordRestrict; + } + os() << " "; + Print(op->symbol); + vectorized_tensor_names_.insert(utils::GetStreamCnt(op->symbol)); + // skip "=0" in "half8 temp = 0;" sincethe operator= of half8 may not overloaded. + if (op->body.As() && op->body.As()->value == 0) { + return; + } + os() << " = "; + Print(op->body); + } else { + CodeGenC::Visit(op); + } +} + +bool CodeGenCUDA_Dev::PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger *op, ir::Expr index_expr, bool is_store) { + static constexpr char index2suffix[8] = {'x', 'y', 'z', 'w', 'v', 'u', 't', 's'}; + + // addr of op should be a place of tensor and the index is simple int number + if (!op->is_addr_tensor() || !index_expr.As()) { + return false; + } + auto *tensor = op->tensor.As(); + CHECK(tensor); + + // identify vectorized tensors by their names + if (!vectorized_tensor_names_.count(tensor->name)) { + return false; + } + + // the index can't exceed the range of cuda built-in vector type + int index = index_expr.As()->value; + if (index < 0 || index >= 8) { + return false; + } + if (is_store && tensor->type().is_cpp_handle()) { + os() << tensor->name << "[" << index << "]"; + } else { + os() << tensor->name << (tensor->type().is_cpp_handle() ? 
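               // pointer handles are dereferenced with "->", values with "."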
"->" : ".") << index2suffix[index]; + } + return true; +} + +void CodeGenCUDA_Dev::Visit(const ir::Load *op) { + // overload this visit function to especially deal with the case when it accesses + // element at a cuda built-in vector, others still resolve to CodeGenC + if (!PrintBuiltinVectorAccess(op, op->index(), false)) { + CodeGenC::Visit(op); + } +} + +void CodeGenCUDA_Dev::Visit(const ir::Store *op) { + // overload this visit function to especially deal with the case when it accesses + // element at a cuda built-in vector, others still resolve to CodeGenC + if (PrintBuiltinVectorAccess(op, op->index(), true)) { + os() << " = "; + Print(op->value); + } else { + CodeGenC::Visit(op); + } +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_dev.h b/paddle/cinn/backends/codegen_cuda_dev.h new file mode 100644 index 0000000000000..ad7e03024553f --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_dev.h @@ -0,0 +1,110 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/common/common.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/ir/module.h" +#include "cinn/lang/packed_func.h" +#include "cinn/runtime/cinn_runtime.h" + +namespace cinn::ir { +class Module; +} // namespace cinn::ir + +namespace cinn { +namespace backends { + +/** + * CUDA device code generator. + * + * It generates the device function, e.g, the function called "myadd" will have a __global__ functon called + * "myadd_kernel", different from codegen_c, the declaration of the "myadd_kernel" function has an expanded argument + * list, which finally similar to `__global__ void myadd(float* __restrict__ A, float* __restrict__ B, int n);` + */ +class CodeGenCUDA_Dev : public CodeGenC { + public: + explicit CodeGenCUDA_Dev(Target target); + + /** + * Compile the \p module to \p outputs. + */ + void Compile(const ir::Module& module, const Outputs& outputs); + + //! Compile on NVRTC. + std::string Compile(const ir::Module& module, bool for_nvrtc = true); + + std::string Compile(const ir::LoweredFunc& func); + + /** + * \brief Print a function argument in CUDA syntax. Currently, just some decoration of __restrict__. + * @param arg the argument. + * @return the representation in CUDA syntax. + * + * We make it a static to make the test easier. 
+ */ + void PrintFuncArg(const ir::Argument& arg); + + std::string Compile(const ir::Module& module, OutputKind output_kind); + + static const std::string& GetSourceHeader(); + + protected: + void Visit(const ir::_Var_* op) override; + void Visit(const ir::_LoweredFunc_* op) override; + void Visit(const ir::Min* op) override; + void Visit(const ir::Max* op) override; + void Visit(const ir::Alloc* op) override; + void Visit(const ir::Call* op) override; + void Visit(const ir::Load* op) override; + void Visit(const ir::Store* op) override; + void Visit(const ir::Let* op) override; + + // Print element access at a cuda built-in vector on a load/store node + bool PrintBuiltinVectorAccess(const ir::LoadStoreAddrMnger* op, ir::Expr index, bool is_store); + + void PrintBuiltinCodes(); + + void PrintIncludes() override; + + void PrintTempBufferCreation(const ir::Buffer& buffer); + + void PrintTempBufferAliasDefinition(const ir::Buffer& buffer); + + std::vector GenerateBufferAliasExprs(const ir::_LoweredFunc_* op, const std::vector& temp_buffers); + + /** + * Print the function declaration, this is different from C, we expand the arguments and get something like + * `__global__ void myadd(float* __restrict__ A, float* __restrict__ B, int n);` + */ + void PrintFunctionDeclaration(const ir::_LoweredFunc_* op); + + private: + Target target_; + bool for_nvrtc_{false}; + // names of vectorized tensors from `Let` statments where dtypes of the tensors + // are customized_type with customized_type::kcuda_builtin_vector_t prefix + std::unordered_set vectorized_tensor_names_; + static const std::string source_header_; +}; + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_generate_test.cc b/paddle/cinn/backends/codegen_cuda_generate_test.cc new file mode 100644 index 0000000000000..5d4fc35afe663 --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_generate_test.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
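
// Writes CodeGenCUDA_Dev's source header followed by a hand-written sample
// kernel into a .cu file, so the emitted header content can be inspected.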
+ +#include +#include + +#include +#include +#include + +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/backends/codegen_cuda_host.h" +#include "cinn/backends/codegen_cuda_util.h" +#include "cinn/backends/extern_func_jit_register.h" +#include "cinn/backends/llvm/execution_engine.h" +#include "cinn/backends/llvm/simple_jit.h" +#include "cinn/cinn.h" +#include "cinn/common/ir_util.h" +#include "cinn/common/test_helper.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule.h" +#include "cinn/lang/lower.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/utils/timer.h" + +namespace cinn { +namespace backends { + +TEST(CUDAFile, Module_output) { + std::string cuda_source_name = "_generated1.cu"; + std::string cuda_source_code = R"ROC( +extern "C" { + +__global__ +void __launch_bounds__(200) elementwise_mul(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C) +{ + if (((int)blockIdx.x < 100)) { + if (((int)threadIdx.x < 200)) { + C[((200 * (int)blockIdx.x) + (int)threadIdx.x)] = (A[((200 * (int)blockIdx.x) + (int)threadIdx.x)] * B[((200 * (int)blockIdx.x) + (int)threadIdx.x)]); + }; + }; +} + +} + )ROC"; + std::ofstream file(cuda_source_name); + CHECK(file.is_open()) << "failed to open file " << cuda_source_name; + file << CodeGenCUDA_Dev::GetSourceHeader(); + file << cuda_source_code; + file.close(); + LOG(WARNING) << "Output C source to file " << cuda_source_name; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_host.cc b/paddle/cinn/backends/codegen_cuda_host.cc new file mode 100644 index 0000000000000..38774b181dbcc --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_host.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
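
// Lowers the host-side launcher of a GPU kernel to LLVM IR: the argument
// types are collected from the kernel's Call node and a direct call to the
// registered runtime launch function is emitted.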
+ +#include "cinn/backends/codegen_cuda_host.h" + +#include +#include +#include + +#include "cinn/backends/codegen_cuda_util.h" +#include "cinn/backends/extern_func_emitter_builtin.h" +#include "cinn/backends/extern_func_jit_register.h" +#include "cinn/backends/llvm/llvm_util.h" +#include "cinn/runtime/intrinsic.h" + +namespace cinn { +namespace backends { + +using cinn::common::bfloat16; +using cinn::common::float16; + +const int kArgsArrayMaxLen = 20; + +llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* func) { + auto body = func->body; + auto* call_ir = body.As(); + CHECK(call_ir); + + // Create the function + // @{ + auto* function_type = GenFunctionTypeFromCinnFunction(func, true); + llvm::Function* function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, func->name, m_); + function->setCallingConv(llvm::CallingConv::C); + function->setHasUWTable(); + + std::vector ll_function_args; + std::transform(function->arg_begin(), function->arg_end(), std::back_inserter(ll_function_args), [](auto& arg) { + return std::addressof(arg); + }); + // @} + + llvm::BasicBlock* entry = llvm::BasicBlock::Create( + /*Context=*/b_->getContext(), + /*Name=*/"entry", + /*Parent=*/function, + /*InsertBefore=*/nullptr); + b_->SetInsertPoint(entry); + + auto* kernel_args = ll_function_args[0]; + auto* kernel_args_count = ll_function_args[1]; + llvm::Value* kernel_stream = nullptr; + if (ll_function_args.size() == 3) { + kernel_stream = ll_function_args[2]; + CHECK_EQ(kernel_stream->getType(), ll_void_p_ty()); // void* stream + } + CHECK_EQ(kernel_args->getType(), ll_void_p_ty()); // void* args + CHECK_EQ(kernel_args_count->getType(), ll_int32_ty()); // int32 + + std::unordered_map global_args = { + {KERNEL_ARGS, kernel_args}, {KERNEL_ARGS_NUM, kernel_args_count}, {KERNEL_STREAM, kernel_stream}}; + + auto ret_type = CinnTypeToLLVMType(Void(), m_); + std::vector args_type; + for (auto r_arg : call_ir->read_args) { + if (r_arg.is_var()) { + if (r_arg.as_var()->type().is_cpp_handle() || r_arg.as_var()->type().is_string()) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.as_var()->type().is_int(32)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else { + CINN_NOT_IMPLEMENTED; + } + } else { + if (r_arg.type().is_bool()) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_uint(8)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_uint(16)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_uint(32)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_uint(64)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_int(8)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_int(16)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_int(32)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_int(64)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_float(32)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_float(64)) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_bfloat16()) { + args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else if (r_arg.type().is_float16()) { + 
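        // fp16 scalars keep a dedicated half-precision type in the signature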
args_type.push_back(CinnTypeToLLVMType(type_of(), m_)); + } else { + CINN_NOT_IMPLEMENTED; + } + } + } + auto func_type = llvm::FunctionType::get(ret_type, args_type, false); + auto call_func = m_->getOrInsertFunction(call_ir->name, func_type); + + std::vector call_args; + for (auto& r_arg : call_ir->read_args) { + if (r_arg.is_var()) { + if (r_arg.as_var()->type().is_string()) { + auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_", b_->getInt8PtrTy()); + call_args.push_back(b_->CreateLoad(b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load")); + } else if (r_arg.as_var()->type().is_cpp_handle() || r_arg.as_var()->type().is_int(32)) { + CHECK(global_args.count(r_arg.as_var()->name)); + call_args.push_back(global_args[r_arg.as_var()->name]); + } else { + CINN_NOT_IMPLEMENTED; + } + } else { + if (r_arg.type().is_bool()) { + call_args.push_back(b_->getInt1(r_arg.as_bool())); + } else if (r_arg.type().is_int(8)) { + call_args.push_back(b_->getInt8(r_arg.as_int8())); + } else if (r_arg.type().is_int(16)) { + call_args.push_back(b_->getInt16(r_arg.as_int16())); + } else if (r_arg.type().is_int(32)) { + call_args.push_back(b_->getInt32(r_arg.as_int32())); + } else if (r_arg.type().is_int(64)) { + call_args.push_back(b_->getInt64(r_arg.as_int64())); + } else if (r_arg.type().is_uint(8)) { + call_args.push_back(b_->getInt8(r_arg.as_uint8())); + } else if (r_arg.type().is_uint(16)) { + call_args.push_back(b_->getInt16(r_arg.as_uint16())); + } else if (r_arg.type().is_uint(32)) { + call_args.push_back(b_->getInt32(r_arg.as_uint32())); + } else if (r_arg.type().is_uint(64)) { + call_args.push_back(b_->getInt64(r_arg.as_uint64())); + } else if (r_arg.type().is_float(32)) { + call_args.push_back(llvm::ConstantFP::get(b_->getFloatTy(), llvm::APFloat(r_arg.as_float()))); + } else if (r_arg.type().is_float(64)) { + call_args.push_back(llvm::ConstantFP::get(b_->getDoubleTy(), llvm::APFloat(r_arg.as_double()))); + } else if (r_arg.type().is_bfloat16()) { + call_args.push_back( + llvm::ConstantFP::get(b_->getBFloatTy(), llvm::APFloat(static_cast(r_arg.as_bfloat16())))); + } else if (r_arg.type().is_float16()) { + call_args.push_back( + llvm::ConstantFP::get(b_->getHalfTy(), llvm::APFloat(static_cast(r_arg.as_float16())))); + } else { + CINN_NOT_IMPLEMENTED; + } + } + } + + b_->CreateCall(call_func, call_args); + RetVoid(); + + return function; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_host.h b/paddle/cinn/backends/codegen_cuda_host.h new file mode 100644 index 0000000000000..4f0b858db4144 --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_host.h @@ -0,0 +1,56 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +#include "cinn/backends/llvm/codegen_llvm.h" + +namespace cinn { +namespace backends { + +/** + * CodeGenCUDA takes a CINN Module with host functions and output a LLVM module. 
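 * The device kernels themselves are compiled separately by CodeGenCUDA_Dev;
 * this class lowers only the host-side launcher stubs.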
+ */ +class CodeGenCUDA_Host : public CodeGenLLVM { + public: + explicit CodeGenCUDA_Host(llvm::Module *m, llvm::IRBuilder<> *b, const std::shared_ptr &vars = nullptr) + : CodeGenLLVM(m, b, vars) {} + + using CodeGenLLVM::Visit; + llvm::Value *Visit(const ir::_LoweredFunc_ *func) override { return LowerGPUKernelLauncher(func); } + + private: + /** + * Lower a CUDA kernel launcher. + * + * We launch a CUDA kernel in the following way: + * + * 1. a GPU function (called fn) will compiled to PTX and lower by CUDA driver to a function pointer, which we store + * as a `void*` type global variable [fn_kernel_ptr] in LLVM module. + * 2. when lower the host launcher, we replace the Call of the original kernel [fn] to a Call of + * `cinn_call_cuda_kernel` method which is registered as an external function. + * + */ + llvm::Value *LowerGPUKernelLauncher(const ir::_LoweredFunc_ *func); +}; + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_util.cc b/paddle/cinn/backends/codegen_cuda_util.cc new file mode 100644 index 0000000000000..ee7174be9f407 --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_util.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/codegen_cuda_util.h" + +#include "cinn/backends/cuda_util.h" +#include "cinn/ir/ir_mutator.h" + +namespace cinn { +namespace backends { + +std::tuple SplitCudaAndHostModule(ir::Module module) { + detail::CollectHostFunctionVisitor visitor(module->name); + Expr expr(module); + return visitor(&expr); +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_util.h b/paddle/cinn/backends/codegen_cuda_util.h new file mode 100755 index 0000000000000..598feede403ae --- /dev/null +++ b/paddle/cinn/backends/codegen_cuda_util.h @@ -0,0 +1,140 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/optim/ir_copy.h" + +namespace cinn { +namespace backends { + +#define KERNEL_ARGS "kernel_args" +#define KERNEL_ARGS_NUM "kernel_args_num" +#define KERNEL_STREAM "kernel_stream" + +/** + * Split a CINN Module into two separate modules, one cantains the host functions, the other contains the device + * kernels. 
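 * For a kernel `fn`, the host module receives a launcher stub still named
 * `fn`, while the device module receives the renamed kernel `fn_kernel`.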
+ * + * This contains some process: + * + * - replace the original kernel function with a Call node and add it to the first module, add a device kernel function + * to the second module. + */ +std::tuple SplitCudaAndHostModule(ir::Module module); + +namespace detail { + +struct CollectHostFunctionVisitor : public ir::IRMutator<> { + explicit CollectHostFunctionVisitor(const std::string& module_name) + : host_module_builder(module_name + "_host", common::DefaultHostTarget()), + device_module_builder(module_name + "_gpu_device", common::DefaultNVGPUTarget()) {} + + std::tuple operator()(Expr* expr) { + ir::IRMutator<>::Visit(expr, expr); + return std::make_tuple(host_module_builder.Build(), device_module_builder.Build()); + } + + private: + void Visit(const ir::_LoweredFunc_* op, Expr* expr) override { + if (op->body.As()) { + host_module_builder.AddFunction(expr->as_lowered_func_ref()); + } else { + if (!op->cuda_axis_info.valid()) { + expr->as_lowered_func_ref()->cuda_axis_info.set_valid(true); + } + auto host_func = CreateHostFunctionGivenDeviceKernel(op); + host_module_builder.AddFunction(host_func.as_lowered_func_ref()); + device_module_builder.AddFunction(CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref()); + } + } + + /** + * Create a wrapper function for a kernel. + * + * For example, we get a kernel function: + * + * \code + * __global__ + * void fn (float* a, float* out) { ... } + * \endcode + * + * A host wrapper function will generate for it + * + * \code + * void fn (cinn_buffer_t* a, cinn_buffer_t* out) { + * Call(fn_kernel); + * } + * \endcode + */ + Expr CreateHostFunctionGivenDeviceKernel(const ir::_LoweredFunc_* func) { + // std::vector args; + // NOTE the suffix `__ptr` makes this argument lower to a pointer in LLVM backend. 
+ // args.push_back(Var("args__ptr", type_of())); + // args.push_back(Var("num_args", type_of())); + ir::Var kernel_ptr(GenDeviceKernelName(func->name), type_of()); + ir::Var kernel_args(KERNEL_ARGS, type_of()); + ir::Var kernel_args_num(KERNEL_ARGS_NUM, type_of()); + ir::Var kernel_stream(KERNEL_STREAM, type_of()); + + auto call_extern_api = ir::Call::Make(Void(), + runtime::intrinsic::call_cuda_kernel, + {kernel_ptr, + kernel_args, + kernel_args_num, + Expr(func->cuda_axis_info.grid_dim(0)), // grid_x + Expr(func->cuda_axis_info.grid_dim(1)), // grid_y + Expr(func->cuda_axis_info.grid_dim(2)), // grid_z + Expr(func->cuda_axis_info.block_dim(0)), // block_x + Expr(func->cuda_axis_info.block_dim(1)), // block_y + Expr(func->cuda_axis_info.block_dim(2)), // block_z + kernel_stream}, + {}, + ir::CallType::Extern, + ir::FunctionRef(), + 0); + std::vector arguments = {ir::Argument(kernel_args, ir::Argument::IO::kOutput), + ir::Argument(kernel_args_num, ir::Argument::IO::kInput), + ir::Argument(kernel_stream, ir::Argument::IO::kOutput)}; + + return ir::_LoweredFunc_::Make(func->name, arguments, call_extern_api, {}); + } + + Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) { + auto copied = optim::IRCopy(expr); + auto* lowered_func = copied.as_lowered_func(); + lowered_func->name = GenDeviceKernelName(lowered_func->name); + return copied; + } + + inline std::string GenDeviceKernelName(const std::string& fn) { return fn + "_kernel"; } + + private: + ir::Module::Builder host_module_builder; + ir::Module::Builder device_module_builder; +}; + +} // namespace detail + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/codegen_debug_test.cc b/paddle/cinn/backends/codegen_debug_test.cc new file mode 100644 index 0000000000000..306238f58fe52 --- /dev/null +++ b/paddle/cinn/backends/codegen_debug_test.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/backends/nvrtc/nvrtc_util.h" +#include "cinn/common/context.h" +#include "cinn/runtime/cuda/cuda_module.h" + +namespace cinn { +namespace backends { + +/** + * This file is not a common test, it is used as a util for developers to + * write source CUDA code to debug whether it runs correctly during runtime + */ +using runtime::cuda::CUDAModule; + +/** + * Utility function to create cuda memory of non-empty shape. + * + * @param shape: a non-empty shape for the created cuda memory + * @param data: the data to initialize the cuda memory. 
Function doesn't + * initailize if it is nullptr + * @return the CUdeviceptr pointing to the created memory + */ +template +CUdeviceptr CreateCudaMemory(const std::vector& shape, const T* data) { + CHECK(!shape.empty()) << "Couldn't create CUDA memory for empty shape"; + CUDA_CALL(cudaDeviceSynchronize()); + + int numel = 1; + for (int s : shape) { + numel = numel * s; + } + + CUdeviceptr cuda_ptr = cuMemAlloc(&cuda_ptr, numel * sizeof(T)); + if (data != nullptr) { + CUDA_CALL(cudaMemcpy(reinterpret_cast(cuda_ptr), data, numel * sizeof(T), cudaMemcpyHostToDevice)); + } + return cuda_ptr; +} + +TEST(CodeGenDebug, RunCudaSourceCode) { + common::Context::Global().ResetNameId(); + + std::string source_code = R"ROC( +extern "C" { + +__global__ +void __launch_bounds__(512) fn_relu_1_kernel(const float* __restrict__ var_1, float* __restrict__ Relu_output) +{ + for (int32_t j_0 = 0; j_0 < 8; j_0 += 1) { + for (int32_t j_1 = 0; j_1 < 1; j_1 += 1) { + for (int32_t j_2 = 0; j_2 < 1; j_2 += 1) { + for (int32_t j_3 = 0; j_3 < 8; j_3 += 1) { + for (int32_t j_4 = 0; j_4 < 1; j_4 += 1) { + for (int32_t k_0 = 0; k_0 < 1; k_0 += 1) { + for (int32_t k_1 = 0; k_1 < 7; k_1 += 1) { + for (int32_t k_2 = 0; k_2 < 4; k_2 += 1) { + for (int32_t k_3 = 0; k_3 < 4; k_3 += 1) { + for (int32_t k_4 = 0; k_4 < 1; k_4 += 1) { + for (int32_t a_0 = 0; a_0 < 16; a_0 += 1) { + for (int32_t a_1 = 0; a_1 < 1; a_1 += 1) { + for (int32_t a_2 = 0; a_2 < 1; a_2 += 1) { + for (int32_t a_3 = 0; a_3 < 1; a_3 += 1) { + for (int32_t a_4 = 0; a_4 < 7; a_4 += 1) { + Relu_output[((7 * a_0) + ((7 * a_1) + ((7 * a_2) + ((7 * a_3) + ((100352 * j_0) + ((100352 * j_1) + ((100352 * j_2) + ((12544 * j_3) + ((12544 * j_4) + ((12544 * k_0) + ((1792 * k_1) + ((448 * k_2) + ((112 * k_3) + ((112 * k_4) + a_4))))))))))))))] = max(var_1[((7 * a_0) + ((7 * a_1) + ((7 * a_2) + ((7 * a_3) + ((100352 * j_0) + ((100352 * j_1) + ((100352 * j_2) + ((12544 * j_3) + ((12544 * j_4) + ((12544 * k_0) + ((1792 * k_1) + ((448 * k_2) + ((112 * k_3) + ((112 * k_4) + a_4))))))))))))))], 0.00000000f); + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; + }; +} + +} +)ROC"; + + backends::nvrtc::Compiler compiler; + + std::string ptx = compiler(CodeGenCUDA_Dev::GetSourceHeader() + source_code); + ASSERT_FALSE(ptx.empty()); + + CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX); + CUdeviceptr var = CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); + CUdeviceptr out = CreateCudaMemory(/* shape */ {64 * 112 * 112}, /* data */ nullptr); + + void* args[] = {&var, &out}; + dim3 grid(512, 1, 1); + dim3 block(512, 1, 1); + cuda_module.LaunchKernel(/*device_id*/ 0, "fn_relu_1_kernel", grid, block, args); +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc new file mode 100644 index 0000000000000..798b0a96a216d --- /dev/null +++ b/paddle/cinn/backends/compiler.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/compiler.h" + +#include + +#include "cinn/backends/llvm/runtime_symbol_registry.h" +#include "cinn/common/context.h" +#ifdef CINN_WITH_CUDA +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/backends/codegen_cuda_host.h" +#include "cinn/backends/codegen_cuda_util.h" +#include "cinn/backends/nvrtc/nvrtc_util.h" +#include "cinn/runtime/cuda/cuda_module.h" +#include "cinn/runtime/cuda/cuda_util.h" +#include "cinn/runtime/flags.h" +#endif + +DECLARE_string(cinn_source_code_save_path); + +namespace cinn { +namespace backends { +using ir::Module; + +static constexpr int DebugLogMaxLen = 30000; + +SourceCodePrint::SourceCodePrint() { + if (!FLAGS_cinn_source_code_save_path.empty()) { + LOG(INFO) << "The CINN auto generated source code will writing into file: \"" << FLAGS_cinn_source_code_save_path + << "\""; + of.open(FLAGS_cinn_source_code_save_path, std::ios_base::out); + } +} + +SourceCodePrint::~SourceCodePrint() { + if (of.is_open()) { + of.close(); + } +} + +void SourceCodePrint::write(const std::string& source_code) { + std::lock_guard guard(mtx_); + if (of.is_open()) { + of << source_code << std::endl; + } else if (!FLAGS_cinn_source_code_save_path.empty()) { + LOG(WARNING) << "Failed to open \"" << FLAGS_cinn_source_code_save_path << "\", source code will print."; + if (source_code.size() > DebugLogMaxLen) { + LOG(INFO) << "[CUDA] source code-0:\n" << source_code.substr(0, DebugLogMaxLen); + for (int i = 1; i * DebugLogMaxLen < source_code.size(); ++i) { + LOG(INFO) << "[CUDA] source code-" << i << ":\n" << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen); + } + } else { + LOG(INFO) << "[CUDA] source code:\n" << source_code; + } + } +} + +void Compiler::Build(const Module& module, const std::string& code) { + if (target_.arch == Target::Arch::NVGPU) { + CompileCudaModule(module, code); + } else if (target_.arch == Target::Arch::X86) { + CompileX86Module(module); + } else { + CINN_NOT_IMPLEMENTED + } +} + +std::string Compiler::GetSourceCode(const ir::Module& module) { + if (target_.arch == Target::Arch::NVGPU) { +#ifdef CINN_WITH_CUDA + auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT + auto& host_module = std::get<0>(_host_module_device_module_); + auto& device_module = std::get<1>(_host_module_device_module_); + CodeGenCUDA_Dev codegen(target_); + auto source_code = codegen.Compile(device_module); + return source_code; +#else + CINN_NOT_IMPLEMENTED +#endif + } else { + CINN_NOT_IMPLEMENTED + } +} + +void Compiler::BuildDefault(const Module& module) { + if (target_.arch == Target::Arch::NVGPU) { + CompileCudaModule(module); + } else if (target_.arch == Target::Arch::X86) { + CompileX86Module(module); + } else { + CINN_NOT_IMPLEMENTED + } +} + +void Compiler::CompileCudaModule(const Module& module, const std::string& code) { +#ifdef CINN_WITH_CUDA + auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT + auto& host_module = std::get<0>(_host_module_device_module_); + auto& device_module = std::get<1>(_host_module_device_module_); + VLOG(3) << "[CUDA] host module:\n" << host_module; + + VLOG(3) << "[CUDA] device module:\n" << device_module; + std::string source_code; + if (code.empty()) { + CodeGenCUDA_Dev codegen(target_); + source_code = codegen.Compile(device_module); + } else { + source_code = code; + } + CHECK(!source_code.empty()) << "Compile CUDA C code failed from device module:\n" 
<< device_module; + VLOG(3) << "[CUDA] C:\n" << source_code; + SourceCodePrint::GetInstance()->write(source_code); + using runtime::cuda::CUDAModule; + + nvrtc::Compiler compiler; + auto ptx = compiler(source_code); + CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << source_code; + cuda_module_.reset( + new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX)); + + RuntimeSymbols symbols; + for (auto& fn : device_module.functions()) { + std::string kernel_fn_name = fn->name; + auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); + CHECK(fn_kernel); + + symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast(fn_kernel)); + } + + engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols)); + engine_->Link(host_module); + +#else + CINN_NOT_IMPLEMENTED +#endif +} + +void Compiler::CompileX86Module(const Module& module) { engine_->Link(module); } + +void Compiler::ExportObject(const std::string& path) { engine_->ExportObject(path); } + +void* Compiler::Lookup(absl::string_view fn_name) { + CHECK(engine_); + if (engine_->Lookup(fn_name) != nullptr) { + return engine_->Lookup(fn_name); + } + return nullptr; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h new file mode 100644 index 0000000000000..bba22e60303a6 --- /dev/null +++ b/paddle/cinn/backends/compiler.h @@ -0,0 +1,94 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +#include "cinn/backends/llvm/codegen_llvm.h" +#include "cinn/backends/llvm/execution_engine.h" +#include "cinn/backends/llvm/simple_jit.h" +#include "cinn/lang/packed_func.h" +#ifdef CINN_WITH_CUDA +#include "cinn/runtime/cuda/cuda_module.h" +#endif + +namespace cinn { +namespace backends { + +class SourceCodePrint { + public: + static SourceCodePrint* GetInstance() { + static SourceCodePrint print; + return &print; + } + + void write(const std::string& source_code); + + private: + SourceCodePrint(); + ~SourceCodePrint(); + + std::ofstream of; + std::mutex mtx_; +}; + +class Compiler final { + public: + static std::unique_ptr Create(const Target& target) { + return std::unique_ptr(new Compiler(target)); + } + + /** + * Compile and link to a CINN module. + */ + void Build(const ir::Module& module, const std::string& code = ""); + + void ExportObject(const std::string& path); + + std::string GetSourceCode(const ir::Module& module); + + void BuildDefault(const ir::Module& module); + + /** + * Retrieve a function by \p fn_name. + * @return function address or null if not exists. 
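 * The lookup is delegated to the underlying ExecutionEngine.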
+ */ + void* Lookup(absl::string_view fn_name); + + private: + void CompileCudaModule(const ir::Module& module, const std::string& code = ""); + + void CompileX86Module(const ir::Module& module); + + explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {} + + CINN_DISALLOW_COPY_AND_ASSIGN(Compiler); + + private: + Target target_; + std::unique_ptr engine_; + +#ifdef CINN_WITH_CUDA + std::unique_ptr cuda_module_; +#endif +}; + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/compiler_test.cc b/paddle/cinn/backends/compiler_test.cc new file mode 100644 index 0000000000000..0393c97eb4d5a --- /dev/null +++ b/paddle/cinn/backends/compiler_test.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/compiler.h" + +#include + +#include + +#include "cinn/cinn.h" +#include "cinn/common/test_helper.h" +#include "cinn/hlir/pe/elementwise.h" +#include "cinn/hlir/pe/nn.h" +#include "cinn/runtime/use_extern_funcs.h" +#include "cinn/utils/timer.h" + +namespace cinn { +namespace backends { + +TEST(Compiler, x86) { + Expr M(1024), N(1024); + + auto create_module = [&]() { + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + auto C = Compute( + {M, N}, [=](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C"); + return std::make_tuple(A, B, C); + }; + + { // test x86 + auto _A_B_C_ = create_module(); // NOLINT + auto& A = std::get<0>(_A_B_C_); + auto& B = std::get<1>(_A_B_C_); + auto& C = std::get<2>(_A_B_C_); + + auto stages = CreateStages({C}); + + auto fn = Lower("fn", stages, {A, B, C}); + + ir::Module::Builder builder("some_module", common::DefaultHostTarget()); + builder.AddFunction(fn); + + auto compiler = Compiler::Create(common::DefaultHostTarget()); + compiler->Build(builder.Build()); + + auto* fnp = compiler->Lookup("fn"); + ASSERT_TRUE(fnp); + + auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + + auto args = common::ArgsBuilder().Add(Ab).Add(Bb).Add(Cb).Build(); + reinterpret_cast(fnp)(args.data(), args.size()); + + // test result + auto* Ad = reinterpret_cast(Ab->memory); + auto* Bd = reinterpret_cast(Bb->memory); + auto* Cd = reinterpret_cast(Cb->memory); + for (int i = 0; i < Ab->num_elements(); i++) { + ASSERT_NEAR(Ad[i] + Bd[i], Cd[i], 1e-5); + } + } +} + +#ifdef CINN_WITH_CUDA +TEST(Compiler, cuda) { + Expr M(1024), N(1024); + + auto create_module = [&]() { + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + auto C = Compute( + {M, N}, [=](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C"); + return std::make_tuple(A, B, C); + }; + + { // cuda + auto _A_B_C_ = create_module(); // NOLINT + auto& A = std::get<0>(_A_B_C_); + auto& B = 
std::get<1>(_A_B_C_); + auto& C = std::get<2>(_A_B_C_); + auto stages = CreateStages({C}); + + stages[C]->Bind(0, "blockIdx.x"); + stages[C]->Bind(1, "threadIdx.x"); + + auto fn = Lower("fn", stages, {A, B, C}); + + ir::Module::Builder builder("some_module", common::DefaultHostTarget()); + builder.AddFunction(fn); + + auto compiler = Compiler::Create(common::DefaultNVGPUTarget()); + compiler->Build(builder.Build()); + + auto* fnp = compiler->Lookup("fn"); + ASSERT_TRUE(fnp); + + auto* Ab = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* Bb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* Cb = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + + // allocate CUDA buffer + void *Ag, *Bg, *Cg; + const int num_bytes = Ab->num_elements() * sizeof(float); + cudaMalloc(&Ag, num_bytes); + cudaMalloc(&Bg, num_bytes); + cudaMalloc(&Cg, num_bytes); + + CUDA_CALL(cudaMemcpy(Ag, Ab->memory, num_bytes, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(Bg, Bb->memory, num_bytes, cudaMemcpyHostToDevice)); + CUDA_CALL(cudaMemcpy(Cg, Cb->memory, num_bytes, cudaMemcpyHostToDevice)); + + cinn_buffer_t Abb; + Abb.memory = reinterpret_cast(Ag); + cinn_buffer_t Bbb; + Bbb.memory = reinterpret_cast(Bg); + cinn_buffer_t Cbb; + Cbb.memory = reinterpret_cast(Cg); + + auto args = common::ArgsBuilder().Add(&Abb).Add(&Bbb).Add(&Cbb).Build(); + + utils::Timer timer; + timer.Start(); + void* stream = nullptr; + for (int i = 0; i < 1000; i++) { + reinterpret_cast(fnp)(args.data(), args.size(), stream); + } + + CUDA_CALL(cudaDeviceSynchronize()); + float latency = timer.Stop(); + LOG(INFO) << "latency: " << latency / 1000; + + std::vector ch(M.as_int32() * N.as_int32(), 0.f); + CUDA_CALL(cudaMemcpy(ch.data(), Cg, ch.size() * sizeof(float), cudaMemcpyDeviceToHost)); + + auto* Ad = reinterpret_cast(Ab->memory); + auto* Bd = reinterpret_cast(Bb->memory); + for (int i = 0; i < Ab->num_elements(); i++) { + ASSERT_NEAR(Ad[i] + Bd[i], ch[i], 1e-5); + } + } +} +#endif + +TEST(Compiler, sqrt) { + Expr N(100); + Expr C(10); + Expr H(10); + Expr W(10); + + Placeholder input("input", {N, C, H, W}); + Placeholder mean("mean", {C}); + Placeholder scale("scale", {C}); + Placeholder variance("variance", {C}); + Placeholder bias("bias", {C}); + float epsilon = 0.1f; + + auto A = Compute( + {N, C, H, W}, + [=](Expr n, Expr c, Expr h, Expr w) { + return (input(n, c, h, w) - mean(c)) * scale(c) / lang::Sqrt(variance(c) + Expr(epsilon)) + bias(c); + }, + "A"); + + auto B = hlir::pe::Pool2d(input, {3, 3}, {1, 1}, {1, 1, 1, 1}, "max", false, false); + + auto BB = hlir::pe::BatchNorm_NCHW(input, scale, bias, mean, variance, epsilon, "batchnorm"); + + auto stages = CreateStages({input, mean, scale, variance, A, bias, B[0], BB}); + + auto fn = Lower("fn", stages, {input, mean, scale, bias, variance, A, B[0], BB}); + + Module::Builder builder("some", common::DefaultHostTarget()); + builder.AddFunction(fn); + + auto compiler = Compiler::Create(common::DefaultHostTarget()); + compiler->Build(builder.Build()); +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/cuda_util.cc b/paddle/cinn/backends/cuda_util.cc new file mode 100644 index 0000000000000..fa6f5b25f78df --- /dev/null +++ b/paddle/cinn/backends/cuda_util.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/cuda_util.h" + +#include + +#include "cinn/backends/extern_func_jit_register.h" +#include "cinn/common/target.h" + +namespace cinn { +namespace backends { + +std::string cuda_thread_axis_name(int level) { + switch (level) { + case 0: + return "threadIdx.x"; + break; + case 1: + return "threadIdx.y"; + break; + case 2: + return "threadIdx.z"; + break; + } + return ""; +} + +std::string cuda_block_axis_name(int level) { + switch (level) { + case 0: + return "blockIdx.x"; + break; + case 1: + return "blockIdx.y"; + break; + case 2: + return "blockIdx.z"; + break; + } + return ""; +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/cuda_util.h b/paddle/cinn/backends/cuda_util.h new file mode 100644 index 0000000000000..f86dc177febc8 --- /dev/null +++ b/paddle/cinn/backends/cuda_util.h @@ -0,0 +1,100 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef CINN_WITH_CUDA + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "cinn/runtime/cinn_runtime.h" + +#define CUDA_DRIVER_CALL(func) \ + { \ + auto status = func; \ + if (status != CUDA_SUCCESS) { \ + const char* msg; \ + cuGetErrorString(status, &msg); \ + LOG(FATAL) << "CUDA Driver Error: " #func " failed with error: " << msg; \ + } \ + } + +#define CUDA_CALL(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ + } \ + } + +#define CURAND_CALL(func) \ + { \ + auto status = func; \ + if (status != CURAND_STATUS_SUCCESS) { \ + LOG(FATAL) << "CURAND Error : " << status; \ + } \ + } + +#define CUSOLVER_CALL(func) \ + { \ + auto status = func; \ + if (status != CUSOLVER_STATUS_SUCCESS) { \ + LOG(FATAL) << "CUSOLVER Error: " << status; \ + } \ + } + +#define CUBLAS_CALL(func) \ + { \ + auto status = func; \ + if (status != CUBLAS_STATUS_SUCCESS) { \ + LOG(FATAL) << "CUBLAS Error!"; \ + } \ + } + +#define CUDNN_CALL(func) \ + { \ + auto status = func; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + LOG(FATAL) << "CUDNN Error : " << cudnnGetErrorString(status); \ + } \ + } + +#define NVRTC_CALL(func) \ + { \ + auto status = func; \ + if (status != NVRTC_SUCCESS) { \ + LOG(FATAL) << "NVRTC Error : " << nvrtcGetErrorString(status); \ + } \ + } + +namespace cinn { +namespace backends { + +// CUDA syntax for thread axis. 
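// e.g. level 0 -> "threadIdx.x", 1 -> "threadIdx.y", 2 -> "threadIdx.z".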
+std::string cuda_thread_axis_name(int level); + +// CUDA syntax for block axis. +std::string cuda_block_axis_name(int level); + +} // namespace backends +} // namespace cinn + +#endif // CINN_WITH_CUDA diff --git a/paddle/cinn/backends/extern_func_emitter.cc b/paddle/cinn/backends/extern_func_emitter.cc new file mode 100644 index 0000000000000..bede4f99ff198 --- /dev/null +++ b/paddle/cinn/backends/extern_func_emitter.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/extern_func_emitter.h" + +#include +#include + +#include +#include +#include + +#include "cinn/backends/extern_func_emitter_builtin.h" +#include "cinn/backends/llvm/runtime_symbol_registry.h" +#include "cinn/runtime/cpu/host_intrinsics.h" +#include "cinn/runtime/flags.h" +#include "cinn/utils/string.h" + +DECLARE_bool(verbose_function_register); + +namespace cinn { +namespace backends { + +ExternFunctionEmitterRegistry& ExternFunctionEmitterRegistry::Global() { + static ExternFunctionEmitterRegistry x; + return x; +} + +void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name, const std::string& x) { +#ifdef CINN_WITH_DEBUG + if (FLAGS_verbose_function_register) { + RAW_LOG_INFO("Register extern function emitter [%s]", utils::GetStreamCnt(name).c_str()); + } +#endif // CINN_WITH_DEBUG + CHECK(!x.empty()) << "Extern Function name is empty."; + data_[name] = x; +} + +const std::string& ExternFunctionEmitterRegistry::Lookup(const ExternFuncID& name) const { + static const std::string not_found = ""; + auto it = data_.find(name); + if (it != data_.end()) { + return it->second; + } + return not_found; +} + +std::ostream& operator<<(std::ostream& os, const ExternFuncID& x) { + os << x.name << ":" << x.backend_id; + return os; +} + +ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {} + +const FunctionProto& ExternFunctionEmitter::func_proto() const { + auto* proto = ExternFunctionProtoRegistry::Global().Lookup(func_name()); + CHECK(proto) << "No prototype of function [" << func_name() << "]"; + return *proto; +} + +} // namespace backends +} // namespace cinn + +namespace std { + +size_t hash::operator()(const cinn::backends::ExternFuncID& x) const { + return absl::Hash{}(x.name) ^ absl::Hash{}(x.backend_id); +} + +} // namespace std diff --git a/paddle/cinn/backends/extern_func_emitter.h b/paddle/cinn/backends/extern_func_emitter.h new file mode 100644 index 0000000000000..b2b8870d51124 --- /dev/null +++ b/paddle/cinn/backends/extern_func_emitter.h @@ -0,0 +1,134 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * \file Implements the ExternFuncEmitter class, which is the base of all the emitter of extern function in the + * backends. + */ + +#pragma once +#include + +#include +#include +#include + +#include "cinn/backends/extern_func_protos.h" +#include "cinn/ir/ir.h" + +namespace cinn { +namespace backends { +class ExternFuncID; +} // namespace backends +} // namespace cinn + +namespace std { +template <> +struct hash { + size_t operator()(const cinn::backends::ExternFuncID& x) const; +}; +} // namespace std + +namespace cinn { +namespace backends { + +//! IDs of backends. +static const char* backend_C = "C"; +static const char* backend_llvm_host = "llvm_host"; +static const char* backend_llvm_x86 = "llvm_x86"; +static const char* backend_nvgpu = "nvgpu"; + +/** + * \brief Base class of the emitter of all the extern functions able to trigger inside CINN CodeGen system. + * There are some common attributes and interfaces. + */ +class ExternFunctionEmitter { + public: + ExternFunctionEmitter() = default; + + virtual void BindCodeGen(void* codegen) = 0; + /** + * Get the name of the function. + */ + virtual const char* func_name() const = 0; + /** + * Emit a store node, if the call node's RetValuePacked is true, otherwise Emit a Call node. + */ + + void Emit(const ir::Call* op, bool insert_global_if_missing = false) { + insert_global_if_missing_ = insert_global_if_missing; + func_proto().AssertMatch(op); + EmitImpl(op); + } + + const FunctionProto& func_proto() const; + + /** + * \brief Tell whether the return value is packed to the argument list. + * + * e.g. Given the original IR + * \code + * s = Call(some_func, arg0) + * \endcode + * + * If this function returns true, some pass will applied and transform the IR to + * \code + * Call(some_func, get_addr(s) + * \endcode + * + * The `RetValuePacked` should be true when the external function modify an existing buffer (or some view of it) due + * to that the C language can't return a container. + */ + virtual bool RetValuePacked() const = 0; + + /** + * @return the backend identifier of this emitter. 
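 * e.g. one of the backend_* IDs declared above, such as "llvm_host" or
 * "nvgpu".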
+ */ + virtual const char* backend_kind() const = 0; + + protected: + virtual void EmitImpl(const ir::Call* op) = 0; + + bool insert_global_if_missing_ = false; +}; + +struct ExternFuncID { + std::string name; + std::string backend_id; + + ExternFuncID(const char* name, const char* backend_id) : name(name), backend_id(backend_id) {} + + friend std::ostream& operator<<(std::ostream& os, const ExternFuncID& x); + friend bool operator==(const ExternFuncID& a, const ExternFuncID& b) { + return a.name == b.name && a.backend_id == b.backend_id; + } +}; + +class ExternFunctionEmitterRegistry { + public: + static ExternFunctionEmitterRegistry& Global(); + + void Register(const ExternFuncID& name, const std::string& x); + + const std::string& Lookup(const ExternFuncID& name) const; + + private: + absl::flat_hash_map data_; + + ExternFunctionEmitterRegistry(); + CINN_DISALLOW_COPY_AND_ASSIGN(ExternFunctionEmitterRegistry); +}; + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/extern_func_emitter_builtin.cc b/paddle/cinn/backends/extern_func_emitter_builtin.cc new file mode 100644 index 0000000000000..087ddc6b81d33 --- /dev/null +++ b/paddle/cinn/backends/extern_func_emitter_builtin.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
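
// Implements the LLVM-backed extern function emitter: the registered
// function prototype is translated to an llvm::FunctionType, then a direct
// call to the function is built at the current insertion point.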
+ +#include "cinn/backends/extern_func_emitter_builtin.h" + +#include + +#include "cinn/backends/llvm/ir_builder_mixin.h" +#include "cinn/backends/llvm/llvm_util.h" + +namespace cinn { +namespace backends { + +void ExternFunctionLLVMEmitter::BindCodeGen(void* codegen) { codegen_ = reinterpret_cast(codegen); } + +const char* ExternFunctionLLVMEmitter::func_name() const { return fn_name_.c_str(); } + +bool ExternFunctionLLVMEmitter::RetValuePacked() const { return fn_proto().ret_type.is_void(); } + +FunctionProto& ExternFunctionLLVMEmitter::fn_proto() const { + auto* proto = ExternFunctionProtoRegistry::Global().Lookup(fn_name_); + CHECK(proto) << "No function prototype found for " << fn_name_; + return *proto; +} +llvm::FunctionType* ExternFunctionLLVMEmitter::llvm_fn_type() const { + auto* proto = ExternFunctionProtoRegistry::Global().Lookup(fn_name_); + CHECK(proto) << "No function prototype found for " << fn_name_; + + auto* llvm_ret_type = CinnTypeToLLVMType(proto->ret_type, codegen_->m()); + std::vector arg_types; + for (auto& t : proto->readonly_arg_types) { + arg_types.push_back(CinnTypeToLLVMType(t, codegen_->m())); + } + for (auto& t : proto->mutable_arg_types) { + arg_types.push_back(CinnTypeToLLVMType(t, codegen_->m())); + } + auto* fn_type = llvm::FunctionType::get(llvm_ret_type, arg_types, false); + return fn_type; +} +const char* ExternFunctionLLVMEmitter::backend_kind() const { return nullptr; } + +void ExternFunctionLLVMEmitter::EmitImpl(const ir::Call* op) { + CHECK(codegen_); + CodeGenLLVMforEmitter codegen_for_emitter(codegen_); + llvm::Function* custom_function = llvm::dyn_cast( + codegen_for_emitter.m()->getOrInsertFunction(fn_name_, llvm_fn_type()).getCallee()); + CHECK(custom_function) << "No function registered in JIT called " << fn_name_; + custom_function->setCallingConv(llvm::CallingConv::C); + + std::vector args; + for (auto& v : op->read_args) { + if (v.as_tensor()) { + args.push_back(codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); + } else { + auto* arg = codegen_for_emitter.Visit(&v); + args.push_back(arg); + } + } + for (auto& v : op->write_args) { + if (v.as_tensor()) { + args.push_back(codegen_for_emitter.GetVar(v.as_tensor()->buffer->name, false)); + } else { + auto* arg = codegen_->Visit(&v); + args.push_back(arg); + } + } + + VLOG(3) << "function type " << op->name << ": " << DumpToString(*custom_function); + + auto* command = codegen_for_emitter.b()->CreateCall(custom_function, args); + codegen_->extern_func_emit_res_ = command; + VLOG(3) << "call: " << DumpToString(*command); +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/extern_func_emitter_builtin.h b/paddle/cinn/backends/extern_func_emitter_builtin.h new file mode 100644 index 0000000000000..59d508e0e8906 --- /dev/null +++ b/paddle/cinn/backends/extern_func_emitter_builtin.h @@ -0,0 +1,61 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "cinn/backends/codegen_c.h"
+#include "cinn/backends/extern_func_emitter.h"
+#include "cinn/backends/extern_func_protos.h"
+#include "cinn/backends/llvm/codegen_llvm.h"
+#include "cinn/backends/llvm/llvm_util.h"
+
+namespace cinn {
+namespace backends {
+
+//! Function names
+
+static const char* extern_tanh_host_repr = "__cinn_host_tanh_fp32";
+static const char* extern_tanh_v_host_repr = "__cinn_host_tanh_v";
+
+/**
+ * A bridge for the Emitters to access CodeGenLLVM's internal members.
+ */
+class CodeGenLLVMforEmitter : public CodeGenLLVM {
+ public:
+  explicit CodeGenLLVMforEmitter(CodeGenLLVM* x) : CodeGenLLVM(x->m(), x->b(), x->named_vars()) {}
+};
+
+class ExternFunctionLLVMEmitter : public ExternFunctionEmitter {
+ public:
+  explicit ExternFunctionLLVMEmitter(const std::string& fn_name) : fn_name_(fn_name) {}
+
+  void BindCodeGen(void* codegen) override;
+  const char* func_name() const override;
+  bool RetValuePacked() const override;
+  const char* backend_kind() const override;
+
+ protected:
+  void EmitImpl(const ir::Call* op) override;
+  FunctionProto& fn_proto() const;
+  llvm::FunctionType* llvm_fn_type() const;
+
+  CodeGenLLVM* codegen_{};
+  std::string fn_name_;
+};
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/extern_func_jit_register.cc b/paddle/cinn/backends/extern_func_jit_register.cc
new file mode 100644
index 0000000000000..1c9113c9f5da3
--- /dev/null
+++ b/paddle/cinn/backends/extern_func_jit_register.cc
@@ -0,0 +1,40 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/backends/extern_func_jit_register.h"
+
+#include <utility>
+
+namespace cinn {
+namespace backends {
+
+void RegisterExternFunctionHelper(const std::string &fn_name,
+                                  std::unique_ptr<FunctionProto> &&fn_proto,
+                                  Target target,
+                                  void *fn_ptr) {
+  ExternFunctionProtoRegistry::Global().Register(fn_name, fn_proto.release());
+  CHECK(ExternFunctionProtoRegistry::Global().Lookup(fn_name));
+
+  ExternFunctionEmitterRegistry::Global().Register(ExternFuncID{TargetToBackendRepr(target), fn_name.c_str()}, fn_name);
+
+  if (fn_ptr) {  // sourced (faked) functions are registered without an address
+    GlobalSymbolRegistry::Global().RegisterFn(fn_name, reinterpret_cast<void *>(fn_ptr));
+  }
+}
+
+void RegisterExternFunction::End() {
+  auto fn_proto = fn_proto_builder_.Build();
+  RegisterExternFunctionHelper(fn_name_, std::move(fn_proto), target_, fn_ptr_);
+}
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/extern_func_jit_register.h b/paddle/cinn/backends/extern_func_jit_register.h
new file mode 100644
index 0000000000000..ad738ec288667
--- /dev/null
+++ b/paddle/cinn/backends/extern_func_jit_register.h
@@ -0,0 +1,161 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * \file This file defines some functions and macros to help register the extern functions into JIT.
+ */
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "cinn/backends/extern_func_emitter.h"
+#include "cinn/backends/extern_func_emitter_builtin.h"
+#include "cinn/backends/extern_func_protos.h"
+#include "cinn/backends/function_prototype.h"
+#include "cinn/backends/llvm/codegen_llvm.h"
+#include "cinn/backends/llvm/ir_builder_mixin.h"
+#include "cinn/backends/llvm/llvm_util.h"
+#include "cinn/backends/llvm/runtime_symbol_registry.h"
+#include "cinn/common/macros.h"
+
+/**
+ * Helper to register an external function into CINN, including the prototype and the function address.
+ * @param fn__: name of the function
+ * @param target__: the Target.
+ */
+#define REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \
+  ::cinn::backends::RegisterExternFunction(#fn__, target__, reinterpret_cast<void*>(fn__))
+
+#define REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) ::cinn::backends::RegisterExternFunction(#fn__, target__)
+
+/**
+ * Register an external function with one input and one output.
+ */
+#define REGISTER_EXTERN_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \
+  REGISTER_EXTERN_FUNC_HELPER(fn__, target__).SetRetType<out_type__>().AddInputType<in_type__>().End()
+
+/**
+ * Register an external function with two inputs and one output.
+ */
+#define REGISTER_EXTERN_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \
+  REGISTER_EXTERN_FUNC_HELPER(fn__, target__)                                               \
+      .SetRetType<out_type__>()                                                             \
+      .AddInputType<in_type1__>()                                                           \
+      .AddInputType<in_type2__>()                                                           \
+      .End()
+
+/**
+ * Register a sourced function (no function address; it is called in generated source code).
+ */
+#define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__).SetRetType<out_type__>().AddInputType<in_type__>().End()
+
+/**
+ * Register a sourced function with two inputs and one output (no function address; it is called in generated source
+ * code).
+ */
+#define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \
+  REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__)                                               \
+      .SetRetType<out_type__>()                                                                    \
+      .AddInputType<in_type1__>()                                                                  \
+      .AddInputType<in_type2__>()                                                                  \
+      .End()
+
+namespace cinn {
+namespace backends {
+
+static const char* TargetToBackendRepr(Target target) {
+  switch (target.arch) {
+    case Target::Arch::X86:
+      return backend_llvm_host;
+    case Target::Arch::NVGPU:
+      return backend_nvgpu;
+    default:
+      CINN_NOT_IMPLEMENTED
+  }
+  return nullptr;
+}
+
+/**
+ * Helper class to register an external function.
+ */
+struct RegisterExternFunction {
+  /**
+   * Constructor.
+   * @param fn_name Name of the function.
+   * @param target Target of the function.
+   * @param fn_ptr Address of the function; not valid if left as null.
+   */
+  RegisterExternFunction(const std::string& fn_name, Target target, void* fn_ptr = nullptr)
+      : fn_name_(fn_name), target_(target), fn_ptr_(fn_ptr), fn_proto_builder_(fn_name) {}
+
+  /**
+   * Add an input type.
+   * @tparam T The input type.
+   * @return itself.
+   */
+  template <typename T>
+  RegisterExternFunction& AddInputType() {
+    fn_proto_builder_.AddInputType<T>();
+    return *this;
+  }
+
+  /**
+   * Add an output type.
+   * @tparam T The output type.
+   * @return itself.
+   */
+  template <typename T>
+  RegisterExternFunction& AddOutputType() {
+    fn_proto_builder_.AddOutputType<T>();
+    return *this;
+  }
+
+  /**
+   * Set the return type.
+   * @tparam T The return type.
+   * @return itself.
+   */
+  template <typename T>
+  RegisterExternFunction& SetRetType() {
+    fn_proto_builder_.SetRetType<T>();
+    return *this;
+  }
+
+  /**
+   * Add a shape inference.
+   * @param handle The handle to help infer the shape.
+   * @return itself.
+   */
+  RegisterExternFunction& SetShapeInference(FunctionProto::shape_inference_t handle) {
+    fn_proto_builder_.SetShapeInference(handle);
+    return *this;
+  }
+
+  /**
+   * End the registration; once ended, further modification is disallowed.
+   */
+  void End();
+
+ private:
+  const std::string& fn_name_;
+  Target target_;
+  void* fn_ptr_{};
+  FunctionProto::Builder fn_proto_builder_;
+};
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/extern_func_protos.cc b/paddle/cinn/backends/extern_func_protos.cc
new file mode 100644
index 0000000000000..58472677b3ea9
--- /dev/null
+++ b/paddle/cinn/backends/extern_func_protos.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
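+
+// How the registration macros defined in extern_func_jit_register.h above are
+// meant to be used, as a sketch; the host function cinn_host_relu_fp32 is
+// hypothetical:
+//
+//   float cinn_host_relu_fp32(float x) { return x > 0.f ? x : 0.f; }
+//   REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_relu_fp32,
+//                                   common::DefaultHostTarget(), float, float);
+//
+// The macro expands to RegisterExternFunction(...).SetRetType<float>()
+// .AddInputType<float>().End(), registering the prototype, the emitter entry,
+// and the symbol address in one statement.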
+
+#include "cinn/backends/extern_func_protos.h"
+
+#include <string>
+#include <vector>
+
+namespace cinn {
+namespace backends {
+
+ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
+  static const std::vector<std::string> extern_funcs_fp32_unary = {
+      "exp",  "erf", "sigmoid", "sqrt", "log",  "log2", "log10", "floor", "ceil",  "round", "trunc", "cos",
+      "cosh", "tan", "tanh",    "sin",  "sinh", "acos", "acosh", "asin",  "asinh", "atan",  "atanh", "fabs"};
+  static const std::vector<std::string> extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"};
+  static const std::vector<std::string> extern_funcs_int_binary = {
+      "left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"};
+  static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
+  for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) {
+    auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32));
+    Register(proto->name, proto);
+  }
+  for (int i = 0; i < extern_funcs_float_bool_unary.size(); ++i) {
+    auto* proto = new FunctionProto(extern_funcs_float_bool_unary[i], {Float(32)}, Bool());
+    Register(proto->name, proto);
+  }
+  for (int i = 0; i < extern_funcs_int_binary.size(); ++i) {
+    auto* proto = new FunctionProto(extern_funcs_int_binary[i], {Int(32), Int(32)}, Int(32));
+    Register(proto->name, proto);
+  }
+  for (int i = 0; i < extern_funcs_int_int_unary.size(); ++i) {
+    auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32));
+    Register(proto->name, proto);
+  }
+
+  auto* n = detail::CreateTanhVProto();
+  Register(n->name, n);
+}
+
+ExternFunctionProtoRegistry& ExternFunctionProtoRegistry::Global() {
+  static ExternFunctionProtoRegistry x;
+  return x;
+}
+
+namespace detail {
+
+FunctionProto* CreateTanhVProto() {
+  return new FunctionProto(
+      extern_func__tanh_v, {type_of<float*>()}, {type_of<float*>()}, Void(), FunctionProto::ShapeFollowNthArgument(0));
+}
+
+}  // namespace detail
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/extern_func_protos.h b/paddle/cinn/backends/extern_func_protos.h
new file mode 100644
index 0000000000000..8b9dbd230dfd5
--- /dev/null
+++ b/paddle/cinn/backends/extern_func_protos.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
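+
+// The registry constructor defined above pre-registers prototypes for the
+// common math functions; a sketch of what a lookup then yields (illustrative):
+//
+//   auto* proto = ExternFunctionProtoRegistry::Global().Lookup("tanh");
+//   // proto->readonly_arg_types == {Float(32)}, proto->ret_type == Float(32)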
+
+#pragma once
+
+#include "cinn/backends/function_prototype.h"
+
+namespace cinn {
+namespace backends {
+
+static const char* extern_func__tanh_v = "tanh_v";
+
+class ExternFunctionProtoRegistry : public FunctionProtoRegistry {
+ public:
+  using FunctionProtoRegistry::Lookup;
+  using FunctionProtoRegistry::Register;
+
+  static ExternFunctionProtoRegistry& Global();
+
+ private:
+  ExternFunctionProtoRegistry();
+  CINN_DISALLOW_COPY_AND_ASSIGN(ExternFunctionProtoRegistry);
+};
+
+namespace detail {
+
+FunctionProto* CreateTanhVProto();
+
+}  // namespace detail
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/function_prototype.cc b/paddle/cinn/backends/function_prototype.cc
new file mode 100644
index 0000000000000..87fb0ec2a40b2
--- /dev/null
+++ b/paddle/cinn/backends/function_prototype.cc
@@ -0,0 +1,130 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/backends/function_prototype.h"
+
+#include <glog/raw_logging.h>
+
+#include <sstream>
+
+#include "cinn/ir/tensor.h"
+#include "cinn/runtime/flags.h"
+
+DECLARE_bool(verbose_function_register);
+
+namespace cinn {
+namespace backends {
+
+bool FunctionProto::Match(const ir::Call *op) const {
+  if (name != op->name) return false;
+  if (ret_type != op->type()) return false;
+  if (op->read_args.size() != readonly_arg_types.size()) return false;
+  if (op->write_args.size() != mutable_arg_types.size()) return false;
+
+  for (int i = 0; i < op->read_args.size(); i++) {
+    if (op->read_args[i].type() != readonly_arg_types[i]) return false;
+  }
+  for (int i = 0; i < op->write_args.size(); i++) {
+    if (op->write_args[i].type() != mutable_arg_types[i]) return false;
+  }
+  return true;
+}
+
+void FunctionProto::AssertMatch(const ir::Call *op) const {
+  CHECK_EQ(name, op->name);
+  CHECK_EQ(ret_type, op->type()) << "function proto " << name << " check failed";
+  CHECK_EQ(op->read_args.size(), readonly_arg_types.size()) << "function proto " << name << " check failed";
+  CHECK_EQ(op->write_args.size(), mutable_arg_types.size()) << "function proto " << name << " check failed";
+
+  auto get_type = [](Expr u) {
+    if (u.as_tensor() || u.as_buffer()) {
+      Type t = u.type();
+      return t.set_cpp_handle();
+    }
+    return u.type();
+  };
+  for (int i = 0; i < op->read_args.size(); i++) {
+    if (readonly_arg_types[i] == type_of<cinn_buffer_t*>()) {
+      if (!op->read_args[i].as_tensor()) continue;
+    } else {
+      CHECK_EQ(get_type(op->read_args[i]), readonly_arg_types[i]);
+    }
+  }
+  for (int i = 0; i < op->write_args.size(); i++) {
+    if (mutable_arg_types[i] == type_of<cinn_buffer_t*>()) {
+      if (!op->write_args[i].as_tensor()) continue;
+    } else {
+      CHECK_EQ(get_type(op->write_args[i]), mutable_arg_types[i]);
+    }
+  }
+}
+
+void FunctionProto::CheckValid() {
+  if (ret_type.is_void()) {
+    CHECK(!mutable_arg_types.empty())
+        << "A void function should have at least one mutable argument to output something";
+  } else {
+    CHECK(mutable_arg_types.empty()) << "A function with a return value should not have mutable arguments";
+  }
+}
+
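+// The rule CheckValid enforces above, illustrated with hypothetical
+// prototypes:
+//
+//   FunctionProto("fill_ones", {}, {type_of<cinn_buffer_t*>()}, Void());  // ok
+//   FunctionProto("scale", {type_of<float>()}, {type_of<cinn_buffer_t*>()},
+//                 Float(32));  // would CHECK-fail: non-void with mutable arg
+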
+FunctionProto::shape_inference_t FunctionProto::ShapeFollowNthArgument(int n) {
+  return [=](const std::vector<Expr> &args, int value_offset) {
+    CHECK_LT(n, args.size());
+    auto x = args[n].as_tensor();
+    CHECK(x);
+    return x->shape;
+  };
+}
+
+FunctionProto::FunctionProto(const std::string &name,
+                             const std::vector<Type> &readonly_arg_types,
+                             const std::vector<Type> &mutable_arg_types,
+                             Type ret_type,
+                             FunctionProto::shape_inference_t shape_inference)
+    : name(name),
+      readonly_arg_types(readonly_arg_types),
+      mutable_arg_types(mutable_arg_types),
+      ret_type(ret_type),
+      shape_inference(shape_inference) {
+  CheckValid();
+}
+
+FunctionProto *FunctionProtoRegistry::Lookup(const std::string &name) {
+  auto it = data_.find(name);
+  if (it != data_.end()) {
+    return it->second.get();
+  }
+  return nullptr;
+}
+
+FunctionProto *FunctionProtoRegistry::Register(absl::string_view name, FunctionProto *x) {
+#ifdef CINN_WITH_DEBUG
+  if (FLAGS_verbose_function_register) {
+    RAW_LOG_INFO("Register function prototype [%s]", name.data());
+  }
+#endif  // CINN_WITH_DEBUG
+  data_.emplace(name, std::unique_ptr<FunctionProto>(x));
+  return x;
+}
+
+std::string FunctionProtoRegistry::debug_string() const {
+  std::stringstream ss;
+  for (auto &item : data_) {
+    ss << item.first << "\n";
+  }
+  return ss.str();
+}
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/function_prototype.h b/paddle/cinn/backends/function_prototype.h
new file mode 100644
index 0000000000000..2ec058fa7edb2
--- /dev/null
+++ b/paddle/cinn/backends/function_prototype.h
@@ -0,0 +1,130 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <absl/container/flat_hash_map.h>
+#include <absl/strings/string_view.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/common/common.h"
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace backends {
+
+struct FunctionProto {
+  using shape_inference_t =
+      std::function<std::vector<Expr> /*shape*/ (const std::vector<Expr>& /*arguments*/, int /*value_offset*/)>;
+
+  std::string name;
+  std::vector<Type> readonly_arg_types;
+  std::vector<Type> mutable_arg_types;
+  Type ret_type;
+
+  // Infers the output's shape.
+  shape_inference_t shape_inference;
+
+  /**
+   * Constructor for a multiple-output function.
+   * @param name Name of the function.
+   * @param readonly_arg_types The input types.
+   * @param mutable_arg_types The output types.
+   * @param ret_type The return type, defaults to Void().
+   * @param shape_inference The shape inference for each of the output tensors.
+   */
+  FunctionProto(const std::string& name,
+                const std::vector<Type>& readonly_arg_types,
+                const std::vector<Type>& mutable_arg_types,
+                Type ret_type = Void(),
+                shape_inference_t shape_inference = shape_inference_t());
+
+  /**
+   * Constructor for a single-output function.
+   * @param name Name of the function.
+   * @param input_types The input types.
+   * @param ret_type The return type.
+   */
+  FunctionProto(const std::string& name, const std::vector<Type>& input_types, Type ret_type)
+      : name(name), readonly_arg_types(input_types), ret_type(ret_type) {}
+
+  /**
+   * Tell whether the Call \p op matches the function prototype.
+   */
+  bool Match(const ir::Call* op) const;
+
+  /**
+   * Assert that the call matches the function prototype.
+   */
+  void AssertMatch(const ir::Call* op) const;
+
+  struct Builder {
+    explicit Builder(const std::string& name) {
+      data_.reset(new FunctionProto);
+      data_->name = name;
+    }
+    template <typename T>
+    Builder& SetRetType() {
+      data_->ret_type = type_of<T>();
+      return *this;
+    }
+    template <typename T>
+    Builder& AddInputType() {
+      data_->readonly_arg_types.push_back(type_of<T>());
+      return *this;
+    }
+    template <typename T>
+    Builder& AddOutputType() {
+      data_->mutable_arg_types.push_back(type_of<T>());
+      return *this;
+    }
+    Builder& SetShapeInference(shape_inference_t fn) {
+      data_->shape_inference = fn;
+      return *this;
+    }
+
+    std::unique_ptr<FunctionProto> Build() { return std::move(data_); }
+
+   private:
+    std::unique_ptr<FunctionProto> data_;
+  };
+
+  /**
+   * All the outputs use the n-th argument's shape.
+   */
+  static shape_inference_t ShapeFollowNthArgument(int n);
+
+ protected:
+  void CheckValid();
+
+  FunctionProto() = default;
+};
+
+class FunctionProtoRegistry {
+ public:
+  FunctionProto* Register(absl::string_view name, FunctionProto* x);
+
+  FunctionProto* Lookup(const std::string& name);
+
+  std::string debug_string() const;
+
+ private:
+  absl::flat_hash_map<std::string, std::unique_ptr<FunctionProto>> data_;
+};
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/generated1.cu b/paddle/cinn/backends/generated1.cu
new file mode 100644
index 0000000000000..88459ce83f588
--- /dev/null
+++ b/paddle/cinn/backends/generated1.cu
@@ -0,0 +1,15 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/backends/_generated1.cu"
diff --git a/paddle/cinn/backends/generated_module1.cc b/paddle/cinn/backends/generated_module1.cc
new file mode 100644
index 0000000000000..4c74a485bec27
--- /dev/null
+++ b/paddle/cinn/backends/generated_module1.cc
@@ -0,0 +1,15 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
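+
+// For reference: the FunctionProto::Builder defined in function_prototype.h
+// above is what the registration macros ultimately drive. A standalone sketch,
+// equivalent to what REGISTER_EXTERN_FUNC_2_IN_1_OUT generates (hypothetical
+// name):
+//
+//   FunctionProto::Builder builder("my_pow");
+//   std::unique_ptr<FunctionProto> proto = builder.SetRetType<float>()
+//                                              .AddInputType<float>()
+//                                              .AddInputType<float>()
+//                                              .Build();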
+ +#include "cinn/backends/_generated_module1.cc" diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc new file mode 100644 index 0000000000000..0d11d4230d911 --- /dev/null +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -0,0 +1,3019 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_schedule.h" + +#include +#include + +#include +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/backends/codegen_cuda_dev.h" +#include "cinn/cinn.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/lower.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/remove_schedule_block.h" +#include "cinn/optim/unroll_loops.h" +#include "cinn/optim/vectorize_loops.h" + +namespace cinn { +namespace backends { + +TEST(IrSchedule, split_and_fuse1) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_split_and_fuse1", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto fused = ir_sch.Fuse("B", {0, 1}); + auto splited = ir_sch.Split(fused, {4, -1}); + + auto loops = ir_sch.GetLoops("B"); + fused = ir_sch.Fuse(loops); + splited = ir_sch.Split(fused, {256, -1}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + std::string target_code = R"ROC( +#include +#include + +void test_split_and_fuse1(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i_j_fused_i_j_fused_0_fused = 0; i_j_fused_i_j_fused_0_fused < 256; i_j_fused_i_j_fused_0_fused += 1) { + for (int32_t i_j_fused_i_j_fused_0_fused_0 = 0; i_j_fused_i_j_fused_0_fused_0 < 4; i_j_fused_i_j_fused_0_fused_0 += 1) { + B[(((i_j_fused_i_j_fused_0_fused / 8) * 32) + (((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31))] = A[(((i_j_fused_i_j_fused_0_fused / 8) * 32) + (((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31))]; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, split_and_fuse2) { + 
Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_split_and_fuse2", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + + auto fused = ir_sch.Fuse(loops); + auto splited = ir_sch.Split(fused, {-1, 20}); + VLOG(3) << "After split {-1, 20}, IR is : " << ir_sch.GetModule().GetExprs().at(0); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(3) << "split_and_fuse2 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_split_and_fuse2(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i_j_fused = 0; i_j_fused < 52; i_j_fused += 1) { + for (int32_t i_j_fused_0 = 0; i_j_fused_0 < 20; i_j_fused_0 += 1) { + if ((((20 * i_j_fused) + i_j_fused_0) < 1024)) { + B[((20 * i_j_fused) + i_j_fused_0)] = A[((20 * i_j_fused) + i_j_fused_0)]; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, reorder1) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_reorder1", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto splited = ir_sch.Split("B", 0, {-1, 4}); + splited = ir_sch.Split("B", 2, {-1, 2}); + + auto loops = ir_sch.GetLoops("B"); + ir_sch.Reorder({loops[4], loops[0]}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(3) << "reorder1 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_reorder1(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t k = 0; k < 32; k += 1) { + for (int32_t i_0 = 0; i_0 < 4; i_0 += 1) { + for (int32_t j = 0; j < 16; j += 1) { + for (int32_t j_0 = 0; j_0 < 2; j_0 += 1) { + for 
(int32_t i = 0; i < 8; i += 1) { + B[((4096 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))] = A[((4096 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))]; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, reorder2) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_reorder2", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto splited = ir_sch.Split("B", 0, {-1, 4}); + splited = ir_sch.Split("B", 2, {-1, 2}); + + ir_sch.Reorder("B", {4, 2, 3, 1, 0}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(3) << "reorder2 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_reorder2(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t k = 0; k < 32; k += 1) { + for (int32_t j = 0; j < 16; j += 1) { + for (int32_t j_0 = 0; j_0 < 2; j_0 += 1) { + for (int32_t i_0 = 0; i_0 < 4; i_0 += 1) { + for (int32_t i = 0; i < 8; i += 1) { + B[((4096 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))] = A[((4096 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))]; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, reorder3) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_reorder3", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto all_blocks = ir_sch.GetAllBlocks(); + auto loops = ir_sch.GetLoops(all_blocks[0]); + + auto splited = ir_sch.Split(loops[0], {-1, 5}); + splited = ir_sch.Split("B", 2, {-1, 2}); + + ir_sch.Reorder("B", {3, 1, 2, 0, 4}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(3) << "reorder3 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_reorder3(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = 
cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t j_0 = 0; j_0 < 2; j_0 += 1) { + for (int32_t i_0 = 0; i_0 < 5; i_0 += 1) { + for (int32_t j = 0; j < 16; j += 1) { + for (int32_t i = 0; i < 7; i += 1) { + if ((((5 * i) + i_0) < 32)) { + for (int32_t k = 0; k < 32; k += 1) { + B[((5120 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))] = A[((5120 * i) + ((1024 * i_0) + ((64 * j) + ((32 * j_0) + k))))]; + }; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, reorder4) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + + auto stages = CreateStages({A, B}); + + auto func = cinn::lang::LowerVec("test_reorder4", stages, {A, B}, {}, {}, nullptr, target, true); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto all_blocks = ir_sch.GetAllBlocks(); + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops(block_b); + + auto splited = ir_sch.Split("B", 0, {-1, 10}); + splited = ir_sch.Split("B", 2, {-1, 5}); + + ir_sch.Reorder("B", {0, 2, 1, 3, 4}); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(3) << "reorder4 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_reorder4(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i = 0; i < 4; i += 1) { + for (int32_t j = 0; j < 7; j += 1) { + for (int32_t i_0 = 0; i_0 < 10; i_0 += 1) { + if ((((10 * i) + i_0) < 32)) { + for (int32_t j_0 = 0; j_0 < 5; j_0 += 1) { + if ((((5 * j) + j_0) < 32)) { + for (int32_t k = 0; k < 32; k += 1) { + B[((10240 * i) + ((1024 * i_0) + ((160 * j) + ((32 * j_0) + k))))] = A[((10240 * i) + ((1024 * i_0) + ((160 * j) + ((32 * j_0) + k))))]; + }; + }; + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +#ifdef CINN_USE_OPENMP +TEST(IrSchedule, parallel) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_parallel", stages, {A, B}, {}, {}, nullptr, target, true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + 
CHECK(!loops.empty()); + ir_sch.Parallel(loops[0]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + std::string target_code = R"ROC( +#include +#include + +void test_parallel(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + int num_task = max_concurrency(); + omp_set_num_threads(num_task); + auto flambda = [=](int task_id, int num_task) -> int { + int n_per_task = (((32 + num_task) - 1) / num_task); + for (int32_t i = (task_id * n_per_task); i < 32 && i < ((task_id + 1) * n_per_task); i += 1) { + for (int32_t j = 0; j < 32; j += 1) { + B[((32 * i) + j)] = A[((32 * i) + j)]; + }; + } + return 0; + }; +#pragma omp parallel num_threads(num_task) + { + int task_id = omp_get_thread_num(); + flambda(task_id, num_task); + }; + cinn_buffer_free((void*)(0), _B); +} +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} +#endif // CINN_USE_OPENMP + +TEST(IrSchedule, vectorize) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_vectorize", stages, {A, B}, {}, {}, nullptr, target, true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 2U); + ir_sch.Vectorize(loops[1], 16); + std::string origin = utils::GetStreamCnt(func[0]); + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_vectorize (_A, _B) +{ + ScheduleBlock(root) + { + serial for (i, 0, 32) + { + vectorize[16] for (j, 0, 32) + { + ScheduleBlock(B) + { + i0, i1 = axis.bind(i, j) + B[i0, i1] = A[i0, i1] + } + } + } + } +} +)ROC")); + optim::VectorizeLoops(&func[0]->body, target); + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + std::string target_code = R"ROC( +#include +#include + +void test_vectorize(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i = 0; i < 32; i += 1) { + for (int32_t j = 0; j < 2; j += 1) { + B[StackVec<16,int32_t>::Ramp(((32 * i) + (16 * j)), 1, 16)] = StackedVec::Load(A,((32 * i) + (16 * j))); + }; + }; + cinn_buffer_free((void*)(0), _B); +} +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, unroll) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(2); + + Target target = 
common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_unroll", stages, {A, B}, {}, {}, nullptr, target, true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 2U); + ir_sch.Unroll(loops[1]); + std::string origin = utils::GetStreamCnt(func[0]); + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_unroll (_A, _B) +{ + ScheduleBlock(root) + { + serial for (i, 0, 32) + { + unroll for (j, 0, 2) + { + ScheduleBlock(B) + { + i0, i1 = axis.bind(i, j) + B[i0, i1] = A[i0, i1] + } + } + } + } +} +)ROC")); + optim::UnrollLoop(&func[0]->body); + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + std::string target_code = R"ROC( +#include +#include + +void test_unroll(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i = 0; i < 32; i += 1) { + B[(2 * i)] = A[(2 * i)]; + B[(1 + (2 * i))] = A[(1 + (2 * i))]; + }; + cinn_buffer_free((void*)(0), _B); +} +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, bind) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(2); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto stages = CreateStages({A, B}); + auto func = cinn::lang::LowerVec("test_bind", stages, {A, B}, {}, {}, nullptr, target, true); + CHECK(!func.empty()); + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + auto loops = ir_sch.GetLoops("B"); + CHECK_EQ(loops.size(), 2U); + ir_sch.Bind(loops[0], "blockIdx.x"); + std::string origin = utils::GetStreamCnt(func[0]); + EXPECT_EQ(origin, utils::Trim(R"ROC( +function test_bind (_A, _B) +{ + ScheduleBlock(root) + { + thread_bind[blockIdx.x] for (i, 0, 32) + { + serial for (j, 0, 2) + { + ScheduleBlock(B) + { + i0, i1 = axis.bind(i, j) + B[i0, i1] = A[i0, i1] + } + } + } + } +} +)ROC")); +} + +TEST(IrSchedule, simple_compute_at) { + Context::Global().ResetNameId(); + Expr M(128); + Expr N(10); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return B(i, j); }, "C"); + + auto stages = CreateStages({A, B, C}); + + auto func = cinn::lang::LowerVec("test_simple_compute_at", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto fused = ir_sch.Fuse("B", {0, 1}); + auto splited = ir_sch.Split(fused, {-1, 1024}); + + fused = ir_sch.Fuse("C", {0, 
1}); + splited = ir_sch.Split(fused, {-1, 1024}); + auto block_b = ir_sch.GetBlock("B"); + ir_sch.SimpleComputeAt(block_b, splited[1]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "simple_compute_at source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_simple_compute_at(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 128, 10 }); + cinn_buffer_malloc((void*)(0), _C); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { + for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { + if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { + { + B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; + C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; + } + }; + }; + }; + cinn_buffer_free((void*)(0), _B); + cinn_buffer_free((void*)(0), _C); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, compute_at0) { + Context::Global().ResetNameId(); + Expr M(128); + Expr N(10); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return B(i, j); }, "C"); + + auto stages = CreateStages({A, B, C}); + + auto func = cinn::lang::LowerVec("test_compute_at0", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto fused = ir_sch.Fuse("B", {0, 1}); + auto splited = ir_sch.Split(fused, {-1, 1024}); + + fused = ir_sch.Fuse("C", {0, 1}); + splited = ir_sch.Split(fused, {-1, 1024}); + auto block_b = ir_sch.GetBlock("B"); + ir_sch.ComputeAt(block_b, splited[1]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at0 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_compute_at0(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 128, 10 }); + cinn_buffer_malloc((void*)(0), _C); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i_j_fused_1 = 0; i_j_fused_1 < 2; i_j_fused_1 += 1) { + 
for (int32_t i_j_fused_2 = 0; i_j_fused_2 < 1024; i_j_fused_2 += 1) { + if ((((1024 * i_j_fused_1) + i_j_fused_2) < 1280)) { + { + B[((1024 * i_j_fused_1) + i_j_fused_2)] = A[((1024 * i_j_fused_1) + i_j_fused_2)]; + C[((1024 * i_j_fused_1) + i_j_fused_2)] = B[((1024 * i_j_fused_1) + i_j_fused_2)]; + } + }; + }; + }; + cinn_buffer_free((void*)(0), _B); + cinn_buffer_free((void*)(0), _C); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, compute_at1) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + auto C = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k); }, "C"); + + auto stages = CreateStages({A, B, C}); + + auto func = cinn::lang::LowerVec("test_compute_at1", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops("C"); + + ir_sch.ComputeAt(block_b, loops[1]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at1 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_compute_at1(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 }); + cinn_buffer_malloc((void*)(0), _C); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i = 0; i < 32; i += 1) { + for (int32_t j = 0; j < 32; j += 1) { + for (int32_t ax0 = 0; ax0 < 32; ax0 += 1) { + B[((1024 * i) + ((32 * j) + ax0))] = A[((1024 * i) + ((32 * j) + ax0))]; + }; + for (int32_t k = 0; k < 32; k += 1) { + C[((1024 * i) + ((32 * j) + k))] = B[((1024 * i) + ((32 * j) + k))]; + }; + }; + }; + cinn_buffer_free((void*)(0), _B); + cinn_buffer_free((void*)(0), _C); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, compute_at2) { + Context::Global().ResetNameId(); + Expr M(64); + Expr N(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, M}); + auto B = Compute( + {M, M}, [&](Var i, Var j) { return A(i, j); }, "B"); + auto C = Compute( + {N, N}, [&](Var i, Var j) { return B(i + j, i + j); }, "C"); + + auto stages = CreateStages({A, B, C}); + + auto func = cinn::lang::LowerVec("test_compute_at2", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops("C"); + + ir_sch.ComputeAt(block_b, loops[0]); + + Module::Builder builder("module1", target); + for (auto& i : func) 
{ + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at2 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_compute_at2(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 64, 64 }); + cinn_buffer_malloc((void*)(0), _C); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i = 0; i < 32; i += 1) { + for (int32_t ax0 = 0; ax0 < 32; ax0 += 1) { + for (int32_t ax1 = 0; ax1 < 32; ax1 += 1) { + B[((64 * ax0) + ((64 * i) + (ax1 + i)))] = A[((64 * ax0) + ((64 * i) + (ax1 + i)))]; + }; + }; + for (int32_t j = 0; j < 32; j += 1) { + C[((32 * i) + j)] = B[((65 * i) + (65 * j))]; + }; + }; + cinn_buffer_free((void*)(0), _B); + cinn_buffer_free((void*)(0), _C); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, compute_at3) { + Context::Global().ResetNameId(); + Expr M(64); + Expr N(32); + + Target target = common::DefaultHostTarget(); + + Placeholder A("A", {M, M}); + auto B = Compute( + {M, M}, [&](Var i, Var j) { return A(i, j); }, "B"); + auto C = Compute( + {M, M}, [&](Var i, Var j) { return B(i, j); }, "C"); + + auto stages = CreateStages({A, B, C}); + + auto func = cinn::lang::LowerVec("test_compute_at3", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto block_b = ir_sch.GetBlock("B"); + + auto fused = ir_sch.Fuse("C", {0, 1}); + auto splited = ir_sch.Split(fused, {32, -1}); + + auto loops = ir_sch.GetLoops("C"); + + ir_sch.ComputeAt(block_b, loops[0]); + + VLOG(1) << "After ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at3 source code is :\n" << source_code; + + std::string target_code = R"ROC( +#include +#include + +void test_compute_at3(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 64, 64 }); + cinn_buffer_malloc((void*)(0), _C); + cinn_buffer_malloc((void*)(0), _B); + const float* A = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + for (int32_t i_j_fused = 0; i_j_fused < 32; i_j_fused += 1) { + for (int32_t ax0 = 0; ax0 < 2; ax0 += 1) { + for (int32_t ax1 = 0; ax1 < 64; ax1 += 1) { + B[((64 * ax0) + ((128 * i_j_fused) + ax1))] = A[((64 * ax0) + ((128 * i_j_fused) + ax1))]; + }; + }; + for (int32_t i_j_fused_0 = 0; i_j_fused_0 < 128; 
i_j_fused_0 += 1) { + C[((128 * i_j_fused) + i_j_fused_0)] = B[((128 * i_j_fused) + i_j_fused_0)]; + }; + }; + cinn_buffer_free((void*)(0), _B); + cinn_buffer_free((void*)(0), _C); +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +#ifdef CINN_WITH_CUDA +TEST(IrSchedule, compute_at4) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Expr P(32); + + Target target = common::DefaultNVGPUTarget(); + + Placeholder A("A", {M, N, P}); + auto B = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B"); + auto C = Compute( + {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k); }, "C"); + + auto stages = CreateStages({A, B, C}); + stages[B]->SetBuffer("local"); + + auto func = cinn::lang::LowerVec("test_compute_at4", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops("C"); + + ir_sch.ComputeAt(block_b, loops[1]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenCUDA_Dev codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at4 source code is :\n" << source_code; + + std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__ +void test_compute_at4(const float* __restrict__ A, float* __restrict__ C) +{ + float _B_temp_buffer [ 32768 ]; + float* B = _B_temp_buffer; + for (int32_t i = 0; i < 32; i += 1) { + for (int32_t j = 0; j < 32; j += 1) { + for (int32_t ax0 = 0; ax0 < 32; ax0 += 1) { + B[((1024 * i) + ((32 * j) + ax0))] = A[((1024 * i) + ((32 * j) + ax0))]; + }; + for (int32_t k = 0; k < 32; k += 1) { + C[((1024 * i) + ((32 * j) + k))] = B[((1024 * i) + ((32 * j) + k))]; + }; + }; + }; +} + +)ROC"; + ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code)); +} + +TEST(IrSchedule, compute_at5) { + Context::Global().ResetNameId(); + Expr M(64); + Expr N(32); + + Target target = common::DefaultNVGPUTarget(); + + Placeholder A("A", {M, M}); + auto B = Compute( + {M, M}, [&](Var i, Var j) { return A(i, j); }, "B"); + auto C = Compute( + {N, N}, [&](Var i, Var j) { return B(i + j, i + j); }, "C"); + + auto stages = CreateStages({A, B, C}); + stages[B]->SetBuffer("local"); + + auto func = cinn::lang::LowerVec("test_compute_at5", stages, {A, C}, {}, {}, nullptr, target, true); + CHECK_EQ(func.size(), 1U); + + auto ast_expr = func[0]->body; + std::vector vec_ast{ast_expr}; + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + + auto block_b = ir_sch.GetBlock("B"); + auto loops = ir_sch.GetLoops("C"); + + ir_sch.ComputeAt(block_b, loops[0]); + + Module::Builder builder("module1", target); + for (auto& i : func) { + builder.AddFunction(i); + } + auto module = builder.Build(); + CodeGenCUDA_Dev codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); + + VLOG(1) << "compute_at5 source code is :\n" << source_code; + + std::string target_code = codegen.GetSourceHeader() + + R"ROC(__global__ +void test_compute_at5(const float* __restrict__ A, float* __restrict__ C) +{ + float _B_temp_buffer [ 4096 ]; + float* B = _B_temp_buffer; + for (int32_t i = 0; i < 32; i += 1) { + for (int32_t ax0 = 0; ax0 < 32; 
+      for (int32_t ax1 = 0; ax1 < 32; ax1 += 1) {
+        B[((64 * ax0) + ((64 * i) + (ax1 + i)))] = A[((64 * ax0) + ((64 * i) + (ax1 + i)))];
+      };
+    };
+    for (int32_t j = 0; j < 32; j += 1) {
+      C[((32 * i) + j)] = B[((65 * i) + (65 * j))];
+    };
+  };
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, compute_at6) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, M});
+  auto B = Compute(
+      {M, M}, [&](Var i, Var j) { return A(i, j); }, "B");
+  auto C = Compute(
+      {M, M}, [&](Var i, Var j) { return B(i, j); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  stages[B]->SetBuffer("local");
+
+  auto func = cinn::lang::LowerVec("test_compute_at6", stages, {A, C}, {}, {}, nullptr, target, true);
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+
+  auto fused = ir_sch.Fuse("C", {0, 1});
+  auto splited = ir_sch.Split(fused, {32, -1});
+
+  auto loops = ir_sch.GetLoops("C");
+
+  ir_sch.ComputeAt(block_b, loops[1]);
+
+  VLOG(1) << "After ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_at6 source code is :\n" << source_code;
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_compute_at6(const float* __restrict__ A, float* __restrict__ C)
+{
+  float _B_temp_buffer [ 4096 ];
+  float* B = _B_temp_buffer;
+  for (int32_t i_j_fused = 0; i_j_fused < 32; i_j_fused += 1) {
+    for (int32_t i_j_fused_0 = 0; i_j_fused_0 < 128; i_j_fused_0 += 1) {
+      B[((128 * i_j_fused) + i_j_fused_0)] = A[((128 * i_j_fused) + i_j_fused_0)];
+      C[((128 * i_j_fused) + i_j_fused_0)] = B[((128 * i_j_fused) + i_j_fused_0)];
+    };
+  };
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+#endif
+
+TEST(IrSchedule, cache_read1) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+  Expr P(16);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, M});
+  auto B = Compute(
+      {N, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+  auto C = Compute(
+      {P, P}, [&](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_cache_read1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto a_cache = ir_sch.CacheRead(block_b, 0, "local");
+  auto block_c = ir_sch.GetBlock("C");
+  auto b_cache = ir_sch.CacheRead(block_c, 0, "local");
+
+  VLOG(1) << "After CacheRead, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_read1 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_cache_read1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t cache_ax0 = 0; cache_ax0 < 32; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      A_local_temp_buffer[((64 * cache_ax0) + cache_ax1)] = A[((64 * cache_ax0) + cache_ax1)];
+    };
+  };
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((64 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 16; cache_ax0_0 += 1) {
+    for (int32_t cache_ax1_0 = 0; cache_ax1_0 < 16; cache_ax1_0 += 1) {
+      B_local_temp_buffer[((32 * cache_ax0_0) + cache_ax1_0)] = B[((32 * cache_ax0_0) + cache_ax1_0)];
+    };
+  };
+  for (int32_t i = 0; i < 16; i += 1) {
+    for (int32_t j = 0; j < 16; j += 1) {
+      C[((16 * i) + j)] = (1.00000000f + B_local_temp_buffer[((32 * i) + j)]);
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, cache_read2) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+
+  auto stages = CreateStages({A, B});
+
+  auto func = cinn::lang::LowerVec("test_cache_read2", stages, {A, B}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+
+  auto a_cache = ir_sch.CacheRead(block_b, 0, "local");
+
+  auto loops = ir_sch.GetLoops("B");
+  ir_sch.ComputeAt(a_cache, loops[1]);
+
+  VLOG(1) << "After CacheRead and ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_read2 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_cache_read2(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      A_local_temp_buffer[((32 * i) + j)] = A[((32 * i) + j)];
+      B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((32 * i) + j)]);
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, cache_write1) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+  auto C = Compute(
+      {M, N}, [&](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_cache_write1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  auto block_c = ir_sch.GetBlock("C");
+  auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
+
+  VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_write1 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_cache_write1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 64, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
+    };
+  };
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 64; cache_ax0_0 += 1) {
+    for (int32_t cache_ax1_0 = 0; cache_ax1_0 < 32; cache_ax1_0 += 1) {
+      C[((32 * cache_ax0_0) + cache_ax1_0)] = C_local_temp_buffer[((32 * cache_ax0_0) + cache_ax1_0)];
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, cache_write2) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+
+  auto stages = CreateStages({A, B});
+
+  auto func = cinn::lang::LowerVec("test_cache_write2", stages, {A, B}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  auto loops = ir_sch.GetLoops("B");
+  ir_sch.ComputeAt(b_cache, loops[1]);
+
+  VLOG(1) << "After CacheWrite and ComputeAt, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_write2 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_cache_write2(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)] = (2.00000000f * A[((32 * cache_ax0) + cache_ax1)]);
+      B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+#ifdef CINN_WITH_CUDA
+TEST(IrSchedule, cache_read3) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+  Expr P(16);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, M});
+  auto B = Compute(
+      {N, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+  auto C = Compute(
+      {P, P}, [&](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  stages[B]->SetBuffer("local");
+
+  auto func = cinn::lang::LowerVec("test_cache_read3", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto a_cache = ir_sch.CacheRead(block_b, 0, "local");
+  auto block_c = ir_sch.GetBlock("C");
+  auto b_cache = ir_sch.CacheRead(block_c, 0, "local");
+  auto loops_c = ir_sch.GetLoops("C");
+  ir_sch.SyncThreads(loops_c[1], false);
+  auto loops_b = ir_sch.GetLoops("B");
+  ir_sch.SyncThreads(loops_b[1]);
+
+  VLOG(1) << "After CacheRead, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_read3 source code is :\n" << source_code;
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_cache_read3(const float* __restrict__ A, float* __restrict__ C)
+{
+  float _B_temp_buffer [ 1024 ];
+  float* B = _B_temp_buffer;
+  for (int32_t cache_ax0 = 0; cache_ax0 < 32; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      A_local_temp_buffer[((64 * cache_ax0) + cache_ax1)] = A[((64 * cache_ax0) + cache_ax1)];
+    };
+  };
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B[((32 * i) + j)] = (2.00000000f * A_local_temp_buffer[((64 * i) + j)]);
+    };
+    __syncthreads();
+  };
+  for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 16; cache_ax0_0 += 1) {
+    for (int32_t cache_ax1_0 = 0; cache_ax1_0 < 16; cache_ax1_0 += 1) {
+      B_local_temp_buffer[((32 * cache_ax0_0) + cache_ax1_0)] = B[((32 * cache_ax0_0) + cache_ax1_0)];
+    };
+  };
+  for (int32_t i = 0; i < 16; i += 1) {
+    __syncthreads();
+    for (int32_t j = 0; j < 16; j += 1) {
+      C[((16 * i) + j)] = (1.00000000f + B_local_temp_buffer[((32 * i) + j)]);
+    };
+  };
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, cache_write3) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+  auto C = Compute(
+      {M, N}, [&](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  stages[B]->SetBuffer("shared");
+
+  auto func = cinn::lang::LowerVec("test_cache_write3", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  auto block_c = ir_sch.GetBlock("C");
+  auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
+  auto loops_c = ir_sch.GetLoops("C");
+  ir_sch.SyncThreads(loops_c[0], false);
+  auto loops_b = ir_sch.GetLoops("B");
+  ir_sch.SyncThreads(loops_b[0]);
+
+  VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_write3 source code is :\n" << source_code;
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_cache_write3(const float* __restrict__ A, float* __restrict__ C)
+{
+  __shared__ float _B_temp_buffer [ 2048 ];
+  float* B = _B_temp_buffer;
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
+    };
+  };
+  __syncthreads();
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
+    };
+  };
+  __syncthreads();
+  for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 64; cache_ax0_0 += 1) {
+    for (int32_t cache_ax1_0 = 0; cache_ax1_0 < 32; cache_ax1_0 += 1) {
+      C[((32 * cache_ax0_0) + cache_ax1_0)] = C_local_temp_buffer[((32 * cache_ax0_0) + cache_ax1_0)];
+    };
+  };
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, sync_threads) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+  auto C = Compute(
+      {M, N}, [&](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
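+  // NOTE: SetBuffer("shared") below places B in CUDA shared memory; that is
+  // why the expected kernel in target_code declares
+  // `__shared__ float _B_temp_buffer [ 2048 ]` (64 * 32 floats).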
+  stages[B]->SetBuffer("shared");
+
+  auto func = cinn::lang::LowerVec("test_sync_threads", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  auto block_c = ir_sch.GetBlock("C");
+  auto c_cache = ir_sch.CacheWrite(block_c, 0, "local");
+  block_c = ir_sch.GetBlock("C");
+  ir_sch.SyncThreads(block_c, false);
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.SyncThreads(block_b);
+
+  VLOG(1) << "After CacheWrite and SyncThreads, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_sync_threads(const float* __restrict__ A, float* __restrict__ C)
+{
+  __shared__ float _B_temp_buffer [ 2048 ];
+  float* B = _B_temp_buffer;
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B_local_temp_buffer[((32 * i) + j)] = (2.00000000f * A[((32 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
+      __syncthreads();
+    };
+  };
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      C_local_temp_buffer[((32 * i) + j)] = (1.00000000f + B[((32 * i) + j)]);
+    };
+  };
+  for (int32_t cache_ax0_0 = 0; cache_ax0_0 < 64; cache_ax0_0 += 1) {
+    for (int32_t cache_ax1_0 = 0; cache_ax1_0 < 32; cache_ax1_0 += 1) {
+      __syncthreads();
+      C[((32 * cache_ax0_0) + cache_ax1_0)] = C_local_temp_buffer[((32 * cache_ax0_0) + cache_ax1_0)];
+    };
+  };
+}
+
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+#endif
+
+TEST(IrSchedule, cache_write4) {
+  Context::Global().ResetNameId();
+  Expr M(64);
+  Expr N(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, N});
+  Var k(32, "k0");
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, j, k), {k}); }, "B");
+
+  auto stages = CreateStages({A, B});
+
+  auto func = cinn::lang::LowerVec("test_cache_write4", stages, {A, B}, {}, {}, nullptr, target, true);
+
+  CHECK_EQ(func.size(), 1U);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  auto loops = ir_sch.GetLoops("B");
+
+  VLOG(1) << "After CacheWrite, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "cache_write4 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_cache_write4(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* B__reduce_init = ((float*)(_B->memory));
+  for (int32_t i = 0; i < 64; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      B__reduce_init[((32 * i) + j)] = 0.00000000f;
+      for (int32_t k0 = 0; k0 < 32; k0 += 1) {
+        B_local_temp_buffer[((32 * i) + j)] = (B_local_temp_buffer[((32 * i) + j)] + A[((1024 * i) + ((32 * j) + k0))]);
+      };
+    };
+  };
+  for (int32_t cache_ax0 = 0; cache_ax0 < 64; cache_ax0 += 1) {
+    for (int32_t cache_ax1 = 0; cache_ax1 < 32; cache_ax1 += 1) {
+      B[((32 * cache_ax0) + cache_ax1)] = B_local_temp_buffer[((32 * cache_ax0) + cache_ax1)];
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, rfactor) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(2);
+  Expr K(16);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, K});
+  Var j(2, "j0");
+  Var k(16, "k0");
+  auto B = Compute(
+      {M},
+      [&](Var i) {
+        return lang::ReduceSum(A(i, j, k), {j, k});
+      },
+      "B");
+
+  auto stages = CreateStages({A, B});
+  auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true);
+  CHECK(!func.empty());
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+  auto loops = ir_sch.GetLoops("B");
+  CHECK_EQ(loops.size(), 3U);
+  auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0);
+  auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
+  CHECK(new_rf_tensor_ref);
+  CHECK(new_rf_tensor_ref->buffer.defined());
+  func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
+  func[0]->PrepareBufferCastExprs();
+  std::string origin = utils::GetStreamCnt(func[0]);
+  EXPECT_EQ(origin, utils::Trim(R"ROC(
+function test_rfactor (_A, _B)
+{
+  ScheduleBlock(root)
+  {
+    {
+      serial for (rf_k0, 0, 16)
+      {
+        serial for (i, 0, 32)
+        {
+          ScheduleBlock(rf_B__reduce_init)
+          {
+            i0, i1_0 = axis.bind(i, rf_k0)
+            rf_B__reduce_init[i1_0, i0] = 0.00000000f
+          }
+          serial for (j0, 0, 2)
+          {
+            ScheduleBlock(rf_B)
+            {
+              i0_0, i1, i2 = axis.bind(i, j0, rf_k0)
+              rf_B[i2, i0_0] = (rf_B[i2, i0_0] + A[i0_0, i1, i2])
+            }
+          }
+        }
+      }
+      serial for (i, 0, 32)
+      {
+        ScheduleBlock(B__reduce_init)
+        {
+          i0 = axis.bind(i)
+          B__reduce_init[i0] = 0.00000000f
+        }
+        serial for (k0, 0, 16)
+        {
+          ScheduleBlock(B)
+          {
+            i0_0, i2 = axis.bind(i, k0)
+            B[i0_0] = (B[i0_0] + rf_B[i2, i0_0])
+          }
+        }
+      }
+    }
+  }
+}
+)ROC"));
+  // optimize pass: add temp buffers
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_rfactor(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* rf__B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 16, 32 });
+  cinn_buffer_malloc((void*)(0), _B);
+  cinn_buffer_malloc((void*)(0), rf__B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* B__reduce_init = ((float*)(_B->memory));
+  float* rf_B = ((float*)(rf__B->memory));
+  float* rf_B__reduce_init = ((float*)(rf__B->memory));
+  for (int32_t rf_k0 = 0; rf_k0 < 16; rf_k0 += 1) {
+    for (int32_t i = 0; i < 32; i += 1) {
+      rf_B__reduce_init[((32 * rf_k0) + i)] = 0.00000000f;
+      for (int32_t j0 = 0; j0 < 2; j0 += 1) {
+        rf_B[((32 * rf_k0) + i)] = (rf_B[((32 * rf_k0) + i)] + A[((32 * i) + ((16 * j0) + rf_k0))]);
+      };
+    };
+  };
+  for (int32_t i = 0; i < 32; i += 1) {
+    B__reduce_init[i] = 0.00000000f;
+    for (int32_t k0 = 0; k0 < 16; k0 += 1) {
+      B[i] = (B[i] + rf_B[((32 * k0) + i)]);
+    };
+  };
+  cinn_buffer_free((void*)(0), rf__B);
+  cinn_buffer_free((void*)(0), _B);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, rfactor1) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(2);
+  Expr K(16);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, K});
+  Var j(2, "j0");
+  Var k(16, "k0");
+  auto B = Compute(
+      {M},
+      [&](Var i) {
+        return lang::ReduceSum(A(i, j, k), {j, k});
+      },
+      "B");
+
+  auto stages = CreateStages({A, B});
+  auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B}, {}, {}, nullptr, target, true);
+  CHECK(!func.empty());
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+  auto loops = ir_sch.GetLoops("B");
+  CHECK_EQ(loops.size(), 3U);
+  auto new_rf_tensor = ir_sch.Rfactor(loops[1], 1);
+  auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
+  CHECK(new_rf_tensor_ref);
+  CHECK(new_rf_tensor_ref->buffer.defined());
+  func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
+  func[0]->PrepareBufferCastExprs();
+  std::string origin = utils::GetStreamCnt(func[0]);
+
+  EXPECT_EQ(origin, utils::Trim(R"ROC(
+function test_rfactor (_A, _B)
+{
+  ScheduleBlock(root)
+  {
+    {
+      serial for (i, 0, 32)
+      {
+        serial for (rf_j0, 0, 2)
+        {
+          ScheduleBlock(rf_B__reduce_init)
+          {
+            i0, i1_0 = axis.bind(i, rf_j0)
+            rf_B__reduce_init[i0, i1_0] = 0.00000000f
+          }
+          serial for (k0, 0, 16)
+          {
+            ScheduleBlock(rf_B)
+            {
+              i0_0, i1, i2 = axis.bind(i, rf_j0, k0)
+              rf_B[i0_0, i1] = (rf_B[i0_0, i1] + A[i0_0, i1, i2])
+            }
+          }
+        }
+      }
+      serial for (i, 0, 32)
+      {
+        ScheduleBlock(B__reduce_init)
+        {
+          i0 = axis.bind(i)
+          B__reduce_init[i0] = 0.00000000f
+        }
+        serial for (j0, 0, 2)
+        {
+          ScheduleBlock(B)
+          {
+            i0_0, i1 = axis.bind(i, j0)
+            B[i0_0] = (B[i0_0] + rf_B[i0_0, i1])
+          }
+        }
+      }
+    }
+  }
+}
+)ROC"));
+  // optimize pass: add temp buffers
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_rfactor(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* rf__B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 2 });
+  cinn_buffer_malloc((void*)(0), _B);
+  cinn_buffer_malloc((void*)(0), rf__B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* B__reduce_init = ((float*)(_B->memory));
+  float* rf_B = ((float*)(rf__B->memory));
+  float* rf_B__reduce_init = ((float*)(rf__B->memory));
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t rf_j0 = 0; rf_j0 < 2; rf_j0 += 1) {
+      rf_B__reduce_init[((2 * i) + rf_j0)] = 0.00000000f;
+      for (int32_t k0 = 0; k0 < 16; k0 += 1) {
+        rf_B[((2 * i) + rf_j0)] = (rf_B[((2 * i) + rf_j0)] + A[((32 * i) + ((16 * rf_j0) + k0))]);
+      };
+    };
+  };
+  for (int32_t i = 0; i < 32; i += 1) {
+    B__reduce_init[i] = 0.00000000f;
+    for (int32_t j0 = 0; j0 < 2; j0 += 1) {
+      B[i] = (B[i] + rf_B[((2 * i) + j0)]);
+    };
+  };
+  cinn_buffer_free((void*)(0), rf__B);
+  cinn_buffer_free((void*)(0), _B);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, rfactor2) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(2);
+  Expr K(16);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, K});
+  Placeholder<float> B("B", {K, N});
+  Var k(16, "k0");
+  auto C = Compute(
+      {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  auto func = cinn::lang::LowerVec("test_rfactor", stages, {A, B, C}, {}, {}, nullptr, target, true);
+  CHECK(!func.empty());
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+  auto loops = ir_sch.GetLoops("C");
+  CHECK_EQ(loops.size(), 3U);
+  auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0);
+  auto* new_rf_tensor_ref = new_rf_tensor.As<ir::_Tensor_>();
+  CHECK(new_rf_tensor_ref);
+  CHECK(new_rf_tensor_ref->buffer.defined());
+  func[0]->temp_bufs.push_back(new_rf_tensor_ref->buffer);
+  func[0]->PrepareBufferCastExprs();
+  std::string origin = utils::GetStreamCnt(func[0]);
+
+  EXPECT_EQ(origin, utils::Trim(R"ROC(
+function test_rfactor (_A, _B, _C)
+{
+  ScheduleBlock(root)
+  {
+    {
+      serial for (rf_k0, 0, 16)
+      {
+        serial for (i, 0, 32)
+        {
+          serial for (j, 0, 2)
+          {
+            ScheduleBlock(rf_C__reduce_init)
+            {
+              i0, i1, i2_0 = axis.bind(i, j, rf_k0)
+              rf_C__reduce_init[i2_0, i0, i1] = 0.00000000f
+            }
+            ScheduleBlock(rf_C)
+            {
+              i0_0, i1_0, i2 = axis.bind(i, j, rf_k0)
+              rf_C[i2, i0_0, i1_0] = (rf_C[i2, i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0]))
+            }
+          }
+        }
+      }
+      serial for (i, 0, 32)
+      {
+        serial for (j, 0, 2)
+        {
+          ScheduleBlock(C__reduce_init)
+          {
+            i0, i1 = axis.bind(i, j)
+            C__reduce_init[i0, i1] = 0.00000000f
+          }
+          serial for (k0, 0, 16)
+          {
+            ScheduleBlock(C)
+            {
+              i0_0, i1_0, i2 = axis.bind(i, j, k0)
+              C[i0_0, i1_0] = (C[i0_0, i1_0] + rf_C[i2, i0_0, i1_0])
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)ROC"));
+  // optimize pass: add temp buffers
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_rfactor(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2]));
+  cinn_buffer_t* rf__C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 16, 32, 2 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), rf__C);
+  const float* A = ((const float*)(_A->memory));
+  const float* B = ((const float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  float* C__reduce_init = ((float*)(_C->memory));
+  float* rf_C = ((float*)(rf__C->memory));
+  float* rf_C__reduce_init = ((float*)(rf__C->memory));
+  for (int32_t rf_k0 = 0; rf_k0 < 16; rf_k0 += 1) {
+    for (int32_t i = 0; i < 32; i += 1) {
+      for (int32_t j = 0; j < 2; j += 1) {
+        rf_C__reduce_init[((2 * i) + ((64 * rf_k0) + j))] = 0.00000000f;
+        rf_C[((2 * i) + ((64 * rf_k0) + j))] = fma(A[((16 * i) + rf_k0)], B[((2 * rf_k0) + j)], rf_C[((2 * i) + ((64 * rf_k0) + j))]);
+      };
+    };
+  };
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 2; j += 1) {
+      C__reduce_init[((2 * i) + j)] = 0.00000000f;
+      for (int32_t k0 = 0; k0 < 16; k0 += 1) {
+        C[((2 * i) + j)] = (C[((2 * i) + j)] + rf_C[((2 * i) + ((64 * k0) + j))]);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), rf__C);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, compute_inline1) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  ir_sch.ComputeInline(block_b);
+  VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_inline1 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_compute_inline1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      for (int32_t k = 0; k < 32; k += 1) {
+        C[((1024 * i) + ((32 * j) + k))] = fma(2.00000000f, A[((32 * i) + ((1024 * j) + k))], 2.00000000f);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, compute_inline2) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_compute_inline2", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto loops = ir_sch.GetLoops("C");
+  ir_sch.ComputeAt(block_b, loops[1]);
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.ComputeInline(block_b);
+  VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_inline2 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_compute_inline2(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      for (int32_t k = 0; k < 32; k += 1) {
+        C[((1024 * i) + ((32 * j) + k))] = fma(2.00000000f, A[((1024 * i) + ((32 * j) + k))], 2.00000000f);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+#ifdef CINN_WITH_CUDA
+TEST(IrSchedule, compute_inline3) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  stages[B]->SetBuffer("local");
+
+  auto func = cinn::lang::LowerVec("test_compute_inline3", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  ir_sch.ComputeInline(block_b);
+  VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_inline3 source code is :\n" << source_code;
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_compute_inline3(const float* __restrict__ A, float* __restrict__ C)
+{
+  float _B_temp_buffer [ 32768 ];
+  float* B = _B_temp_buffer;
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      for (int32_t k = 0; k < 32; k += 1) {
+        C[((1024 * i) + ((32 * j) + k))] = (2.00000000f + (2.00000000f * A[((32 * i) + ((1024 * j) + k))]));
+      };
+    };
+  };
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, compute_inline4) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultNVGPUTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+  stages[B]->SetBuffer("local");
+
+  auto func = cinn::lang::LowerVec("test_compute_inline4", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto loops = ir_sch.GetLoops("C");
+  ir_sch.ComputeAt(block_b, loops[1]);
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.ComputeInline(block_b);
+  VLOG(1) << "After ComputeInline, IR is : " << ir_sch.GetModule().GetExprs().at(0);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenCUDA_Dev codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = codegen.GetSourceHeader() + R"ROC(__global__
+void test_compute_inline4(const float* __restrict__ A, float* __restrict__ C)
+{
+  float _B_temp_buffer [ 32768 ];
+  float* B = _B_temp_buffer;
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      for (int32_t k = 0; k < 32; k += 1) {
+        C[((1024 * i) + ((32 * j) + k))] = (2.00000000f + (2.00000000f * A[((1024 * i) + ((32 * j) + k))]));
+      };
+    };
+  };
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+#endif
+
+TEST(IrSchedule, reverse_compute_inline1) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(64);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return Expr(1.f) + A(i, j); }, "B");
+  auto C = Compute(
+      {N, M}, [&](Var i, Var j) { return Expr(2.f) * B(j, i); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_c = ir_sch.GetBlock("C");
+  ir_sch.ReverseComputeInline(block_c);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_inline1 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_compute_inline1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 64 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 64; j += 1) {
+      C[((32 * j) + i)] = fma(2.00000000f, A[((64 * i) + j)], 2.00000000f);
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, reverse_compute_inline2) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return Expr(1.f) + A(i, j, k); }, "B");
+  auto C = Compute(
+      {N, M, P}, [&](Var i, Var j, Var k) { return Expr(2.f) * B(j, i, k); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_c = ir_sch.GetBlock("C");
+  ir_sch.ReverseComputeInline(block_c);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  VLOG(1) << "compute_inline1 source code is :\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_compute_inline1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 32; i += 1) {
+    for (int32_t j = 0; j < 32; j += 1) {
+      for (int32_t k = 0; k < 32; k += 1) {
+        C[((32 * i) + ((1024 * j) + k))] = fma(2.00000000f, A[((1024 * i) + ((32 * j) + k))], 2.00000000f);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, copytransform1) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_copytransform1", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_c = ir_sch.GetBlock("C");
+  auto loops_c = ir_sch.GetLoops(block_c);
+  auto splited = ir_sch.Split(loops_c[1], {-1, 4});
+  block_c = ir_sch.GetBlock("C");
+  loops_c = ir_sch.GetLoops(block_c);
+  splited = ir_sch.Split(loops_c[0], {-1, 8});
+
+  auto block_b = ir_sch.GetBlock("B");
+  block_c = ir_sch.GetBlock("C");
+
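+  // NOTE: CopyTransformAndLoopInfo below is expected to copy C's loop
+  // transformations -- the two Splits applied above -- onto block B; see how
+  // B's loop nest in target_code is tiled 4 x 8 x 8 x 4 x 32 just like C's.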
+  ir_sch.CopyTransformAndLoopInfo(block_b, block_c);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_copytransform1(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 4; i += 1) {
+    for (int32_t i_0 = 0; i_0 < 8; i_0 += 1) {
+      for (int32_t j = 0; j < 8; j += 1) {
+        for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
+          for (int32_t k = 0; k < 32; k += 1) {
+            B[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (1.00000000f + A[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))]);
+          };
+        };
+      };
+    };
+  };
+  for (int32_t i = 0; i < 4; i += 1) {
+    for (int32_t i_0 = 0; i_0 < 8; i_0 += 1) {
+      for (int32_t j = 0; j < 8; j += 1) {
+        for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
+          for (int32_t k = 0; k < 32; k += 1) {
+            C[((8192 * i) + ((1024 * i_0) + ((128 * j) + ((32 * j_0) + k))))] = (2.00000000f * B[((256 * i) + ((32 * i_0) + ((4096 * j) + ((1024 * j_0) + k))))]);
+          };
+        };
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, copytransform2) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(64);
+  Expr P(128);
+
+  Target target = common::DefaultHostTarget();
+
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
+  auto C = Compute(
+      {M, M, P}, [&](Var i, Var j, Var k) { return B(i, j, k) * Expr(2.f); }, "C");
+
+  auto stages = CreateStages({A, B, C});
+
+  auto func = cinn::lang::LowerVec("test_copytransform2", stages, {A, C}, {}, {}, nullptr, target, true);
+
+  auto ast_expr = func[0]->body;
+  std::vector<Expr> vec_ast{ast_expr};
+  ir::ModuleExpr mod_expr(vec_ast);
+  ir::IRSchedule ir_sch(mod_expr);
+
+  auto block_c = ir_sch.GetBlock("C");
+  auto loops_c = ir_sch.GetLoops(block_c);
+  auto splited = ir_sch.Split(loops_c[1], {-1, 4});
+  block_c = ir_sch.GetBlock("C");
+  loops_c = ir_sch.GetLoops(block_c);
+  splited = ir_sch.Split(loops_c[0], {-1, 8});
+
+  auto block_b = ir_sch.GetBlock("B");
+  block_c = ir_sch.GetBlock("C");
+  ir_sch.CopyTransformAndLoopInfo(block_b, block_c);
+  Module::Builder builder("module1", target);
+  for (auto& i : func) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void test_copytransform2(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 64, 128 });
+  cinn_buffer_malloc((void*)(0), _C);
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* C = ((float*)(_C->memory));
+  for (int32_t i = 0; i < 4; i += 1) {
+    for (int32_t i_0 = 0; i_0 < 8; i_0 += 1) {
+      for (int32_t j = 0; j < 64; j += 1) {
+        for (int32_t k = 0; k < 128; k += 1) {
+          B[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))] = (1.00000000f + A[((65536 * i) + ((8192 * i_0) + ((128 * j) + k)))]);
+        };
+      };
+    };
+  };
+  for (int32_t i = 0; i < 4; i += 1) {
+    for (int32_t i_0 = 0; i_0 < 8; i_0 += 1) {
+      for (int32_t j = 0; j < 8; j += 1) {
+        for (int32_t j_0 = 0; j_0 < 4; j_0 += 1) {
+          for (int32_t k = 0; k < 128; k += 1) {
+            C[((32768 * i) + ((4096 * i_0) + ((512 * j) + ((128 * j_0) + k))))] = (2.00000000f * B[((65536 * i) + ((8192 * i_0) + ((512 * j) + ((128 * j_0) + k))))]);
+          };
+        };
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+  cinn_buffer_free((void*)(0), _C);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, Annotate) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
+
+  auto funcs = cinn::lang::LowerVec(
+      "test_annotate", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+  auto fused = ir_sch.Fuse("B", {0, 1});
+  auto block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k1", int(64));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k2", bool(true));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k3", float(2.0));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k4", std::string("v4"));
+  std::string expected_expr = R"ROC({
+  ScheduleBlock(root)
+  {
+    serial for (i_j_fused, 0, 1024)
+    {
+      ScheduleBlock(B)
+      {
+        i0, i1 = axis.bind((i_j_fused / 32), (i_j_fused % 32))
+        attrs(k1:64, k2:1, k3:2, k4:v4)
+        B[i0, i1] = A[i0, i1]
+      }
+    }
+  }
+})ROC";
+  ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr);
+}
+
+TEST(IrSchedule, Unannotate) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Placeholder<float> A("A", {M, N});
+  auto B = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
+
+  auto funcs = cinn::lang::LowerVec(
+      "test_unannotate", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+  auto fused = ir_sch.Fuse("B", {0, 1});
+  auto block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k1", int(64));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k2", bool(true));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k3", float(2.0));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Annotate(block_b, "k4", std::string("v4"));
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Unannotate(block_b, "k1");
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Unannotate(block_b, "k2");
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Unannotate(block_b, "k3");
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.Unannotate(block_b, "k4");
+  std::string expected_expr = R"ROC({
+  ScheduleBlock(root)
+  {
+    serial for (i_j_fused, 0, 1024)
+    {
+      ScheduleBlock(B)
+      {
+        i0, i1 = axis.bind((i_j_fused / 32), (i_j_fused % 32))
+        B[i0, i1] = A[i0, i1]
+      }
+    }
+  }
+})ROC";
+  ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr);
+}
+
+TEST(IrSchedule, ComplexIndices) {
+  Target target = common::DefaultHostTarget();
+  ir::Expr M(32);
+  ir::Expr K(64);
+
+  Placeholder<float> A("A", {M, K});
+  Var k(K.as_int32(), "reduce_axis_k");
+  ir::Tensor B = Compute(
+      {M}, [&](Var i) { return ReduceSum(A(i, k), {k}); }, "B");
+
+  poly::StageMap stages = CreateStages({B});
+  std::vector<ir::LoweredFunc> funcs =
+      lang::LowerVec("TestIrSchedule_ReduceSum", stages, {A, B}, {}, {}, nullptr, target, true);
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+  VLOG(3) << "Lowered Expr:" << ir_sch.GetModule().GetExprs().front();
+
+  auto loops_b = ir_sch.GetLoops("B");
+  CHECK_EQ(loops_b.size(), 2);
+  ir_sch.Split("B", 0, {8, -1});
+  ir_sch.Split("B", 2, {32, -1});  // after the first Split, the loop nest has grown to 3 loops
+  VLOG(3) << "Splited Expr:" << ir_sch.GetModule().GetExprs().front();
+
+  CHECK_EQ(ir_sch.GetLoops("B").size(), 4);
+  ir_sch.Reorder("B", {2, 0, 3, 1});
+  VLOG(3) << "Reordered Expr:\n" << ir_sch.GetModule().GetExprs().front();
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto a_cache = ir_sch.CacheRead(block_b, 1, "shared");  // actually the read_buffer A should be indexed by 0
+  VLOG(3) << "CacheRead-A Expr:\n" << ir_sch.GetModule().GetExprs().front();
+
+  loops_b = ir_sch.GetLoops("B");
+  ir_sch.ComputeAt(a_cache, loops_b[0]);
+  VLOG(3) << "A_cache-ComputeAt-B Expr:\n" << ir_sch.GetModule().GetExprs().front();
+
+  block_b = ir_sch.GetBlock("B");
+  auto b_cache = ir_sch.CacheWrite(block_b, 0, "local");
+  VLOG(3) << "CacheWrite-B Expr:\n" << ir_sch.GetModule().GetExprs().front();
+
+  auto loops_b_cache =
+      ir_sch.GetLoops(b_cache.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->name);
+  block_b = ir_sch.GetBlock("B");
+  ir_sch.ReverseComputeAt(block_b, loops_b_cache[1]);
+  VLOG(3) << "B-ReverseComputeAt-B_cache Expr:\n" << ir_sch.GetModule().GetExprs().front();
+
+  Module::Builder builder("module1", target);
+  for (auto& i : funcs) {
+    builder.AddFunction(i);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+  VLOG(3) << "scheduled source code:\n" << source_code;
+
+  std::string target_code = R"ROC(
+#include <cinn_runtime.h>
+#include <stdio.h>
+
+void TestIrSchedule_ReduceSum(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_malloc((void*)(0), _B);
+  const float* A = ((const float*)(_A->memory));
+  float* B = ((float*)(_B->memory));
+  float* B__reduce_init = ((float*)(_B->memory));
+  for (int32_t i = 0; i < 8; i += 1) {
+    for (int32_t i_0 = 0; i_0 < 4; i_0 += 1) {
+      B__reduce_init[((4 * i) + i_0)] = 0.00000000f;
+    };
+  };
+  for (int32_t reduce_axis_k = 0; reduce_axis_k < 32; reduce_axis_k += 1) {
+    for (int32_t ax0 = 0; ax0 < 32; ax0 += 1) {
+      for (int32_t ax1 = 0; ax1 < 2; ax1 += 1) {
+        A_shared_temp_buffer[((64 * ax0) + ((2 * reduce_axis_k) + ax1))] = A[((64 * ax0) + ((2 * reduce_axis_k) + ax1))];
+      };
+    };
+    for (int32_t i = 0; i < 8; i += 1) {
+      for (int32_t reduce_axis_k_0 = 0; reduce_axis_k_0 < 2; reduce_axis_k_0 += 1) {
+        for (int32_t i_0 = 0; i_0 < 4; i_0 += 1) {
+          B_local_temp_buffer[((4 * i) + i_0)] = (B_local_temp_buffer[((4 * i) + i_0)] + A_shared_temp_buffer[((256 * i) + ((64 * i_0) + ((2 * reduce_axis_k) + reduce_axis_k_0)))]);
+        };
+      };
+      for (int32_t ax0_0 = 0; ax0_0 < 4; ax0_0 += 1) {
+        B[((4 * i) + ax0_0)] = B_local_temp_buffer[((4 * i) + ax0_0)];
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _B);
+}
+)ROC";
+  ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
+}
+
+TEST(IrSchedule, SamplePerfectTile) {
+  Context::Global().ResetNameId();
+  Expr M(1024);
+  Placeholder<float> A("A", {M});
+  auto B = Compute(
+      {M}, [&](Expr i) { return A(i) + 1; }, "B");
+  poly::StageMap stages = CreateStages({A, B});
+
+  auto funcs = cinn::lang::LowerVec(
+      "test_sampleperfecttile", stages, {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);
+
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+  auto loops_b = ir_sch.GetLoops("B");
+  std::vector<Expr> result = ir_sch.SamplePerfectTile(loops_b[0], 3, 64);
+  ASSERT_EQ(result.size(), 3);
+}
+
+TEST(IrSchedule, GetChildBlocks) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr K(32);
+  Placeholder<float> A("A", {M, N, K});
+  auto B = Compute(
+      {M, N, K}, [&A](Var i, Var j, Var k) { return A(i, j, k); }, "B");
+  auto C = Compute(
+      {M, N, K}, [&B](Var i, Var j, Var k) { return B(i, j, k); }, "C");
+  auto funcs = cinn::lang::LowerVec(
+      "test_getchildblocks", CreateStages({A, B, C}), {A, C}, {}, {}, nullptr, common::DefaultHostTarget(), true);
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+
+  auto block_b = ir_sch.GetBlock("B");
+  auto loops = ir_sch.GetLoops("C");
+  ir_sch.ComputeAt(block_b, loops[1]);
+  loops = ir_sch.GetLoops("B");
+  auto root_block = ir_sch.GetRootBlock(loops[1]);
+
+  std::string expected_expr = R"ROC(ScheduleBlock(B)
+{
+  i0, i1, i2 = axis.bind(i, j, (0 + ax0))
+  attrs(compute_at_extra_var:ax0)
+  B[i0, i1, i2] = A[i0, i1, i2]
+}, ScheduleBlock(C)
+{
+  i0_0, i1_0, i2_0 = axis.bind(i, j, k)
+  C[i0_0, i1_0, i2_0] = B[i0_0, i1_0, i2_0]
+})ROC";
+  ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetChildBlocks(root_block)), expected_expr);
+}
+
+TEST(IrSchedule, SampleCategorical) {
+  Context::Global().ResetNameId();
+  Expr M(32);
+  Expr N(32);
+  Expr P(32);
+  Placeholder<float> A("A", {M, N, P});
+  auto B = Compute(
+      {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k); }, "B");
+  poly::StageMap stages = CreateStages({A, B});
+  std::vector<int> decision;
+  auto funcs = cinn::lang::LowerVec(
+      "test_samplecategorical", stages, {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);
+
+  ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
+  Expr result = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0}, {decision});
+  ASSERT_EQ(result.type(), Int(32));
+}
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/CMakeLists.txt b/paddle/cinn/backends/llvm/CMakeLists.txt
new file mode 100755
index 0000000000000..f405b6b8801b6
--- /dev/null
+++ b/paddle/cinn/backends/llvm/CMakeLists.txt
@@ -0,0 +1,41 @@
+add_definitions(${LLVM_DEFINITIONS})
+
+# generate cinn_runtime.ll file
+
+add_custom_command(
+  OUTPUT ${CMAKE_BINARY_DIR}/cinn/backends/llvm/cinn_runtime_llvm_ir.h
+  COMMAND ${LLVM_PATH}/bin/clang++ -mavx2 -std=c++11 -masm=intel -S -emit-llvm -O3 ${PROJECT_SOURCE_DIR}/cinn/runtime/cinn_runtime.cc -I${PROJECT_SOURCE_DIR} -o ${CMAKE_BINARY_DIR}/cinn/runtime/cinn_runtime.ll
+  COMMAND ${PYTHON_EXECUTABLE} generate_runtime_llvm_ir.py ${CMAKE_BINARY_DIR}/cinn/runtime/cinn_runtime.ll ${CMAKE_BINARY_DIR}/cinn/backends/llvm/cinn_runtime_llvm_ir.h ${LLVM_PATH}/bin/llvm-config
+  WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/cinn/backends/llvm
+  DEPENDS ${PROJECT_SOURCE_DIR}/cinn/runtime/cinn_runtime.cc ${PROJECT_SOURCE_DIR}/cinn/runtime/cinn_runtime.h
+  )
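+# NOTE: the custom command above is a two-step pipeline: clang++ lowers
+# cinn_runtime.cc to textual LLVM IR (cinn_runtime.ll), then
+# generate_runtime_llvm_ir.py embeds that .ll file into the generated header
+# cinn_runtime_llvm_ir.h, which the GEN_LLVM_RUNTIME_IR_HEADER target below
+# tracks.
+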
+add_custom_target(GEN_LLVM_RUNTIME_IR_HEADER ALL + DEPENDS ${CMAKE_BINARY_DIR}/cinn/backends/llvm/cinn_runtime_llvm_ir.h + ) + +set(srcs + llvm_util.cc + runtime_symbol_registry.cc + codegen_llvm.cc + codegen_x86.cc + simple_jit.cc + execution_engine.cc + llvm_optimizer.cc +) + + +cc_test(test_codegen_llvm SRCS codegen_llvm_test.cc DEPS cinncore) +#cc_test(test_execution_engine SRCS execution_engine_test.cc DEPS cinncore) +cc_test(test_codegen_x86 SRCS codegen_x86_test.cc DEPS cinncore) + +foreach(cpp ${srcs}) + set(cinnapi_src + "${cinnapi_src};cinn/backends/llvm/${cpp}" + CACHE INTERNAL "") +endforeach() + +file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h) + +foreach(header ${includes}) + set(core_includes "${core_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/cinn/backends/llvm/codegen_llvm.cc b/paddle/cinn/backends/llvm/codegen_llvm.cc new file mode 100644 index 0000000000000..169fe3cfd40e3 --- /dev/null +++ b/paddle/cinn/backends/llvm/codegen_llvm.cc @@ -0,0 +1,1527 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/llvm/codegen_llvm.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/backends/extern_func_emitter.h" +#include "cinn/backends/extern_func_emitter_builtin.h" +#include "cinn/backends/llvm/llvm_util.h" +#include "cinn/common/cas.h" +#include "cinn/common/type.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_verify.h" +#include "cinn/optim/var_mod_simplify.h" +#include "cinn/runtime/cinn_runtime.h" +#include "cinn/runtime/intrinsic.h" +#include "cinn/utils/string.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/Alignment.h" + +namespace cinn { +namespace backends { + +using BinaryInstruction = llvm::Instruction::BinaryOps; +using common::bfloat16; +using common::float16; + +namespace { + +template +auto NodeToExpr(const T *node) { + std::ostringstream oss; + // oss << "\033[32m"; + oss << ir::Expr(const_cast(node)); + // oss << "\033[0m"; + return oss.str(); +} + +bool is_integral_type(common::Type t) { return t.is_int() || t.is_uint(); } + +bool is_floating_type(common::Type t) { return t.is_float(); } + +llvm::Value *EmitComparison(llvm::CmpInst::Predicate predicate, + llvm::Value *lhs, + llvm::Value *rhs, + llvm::IRBuilder<> *b) { + llvm::Value *comparison_result{nullptr}; + if (lhs->getType()->isIntegerTy()) { + comparison_result = b->CreateICmp(predicate, lhs, rhs); + } else { + comparison_result = b->CreateFCmp(predicate, lhs, rhs); + } + + return comparison_result; +} + +#define __IR_EMITTER_NOT_IMPLEMENTED(__op) CINN_NOT_IMPLEMENTED + +int NextPowerOfTwo(int x) { + for (int p2 = 1;; p2 *= 2) { + if (p2 >= x) { + return p2; + } + } + 
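+  // Unreachable: the loop above always returns the first power of two >= x
+  // (x <= 1 yields 1). The return below presumably just silences compilers
+  // that cannot prove the unbounded for-loop terminates.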
return 0; +} + +} // namespace + +CodeGenLLVM::CodeGenLLVM(llvm::Module *m, + llvm::IRBuilder<> *b, + const std::shared_ptr &symbol_table, + const Target &target) + : m_(m), b_(b), symbol_table_(symbol_table), target_(target) { + if (!symbol_table.get()) { + symbol_table_ = std::make_shared(); + } + symbol_table_->PushScope(); // Create a new scope by default. + + md_builder_ = std::make_unique(b_->getContext()); + md_tbaa_root_ = md_builder_->createTBAARoot("cinn-tbaa"); + md_tbaa_alias_set_ = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); + InitTarget(target_); +} + +CodeGenLLVM::~CodeGenLLVM() {} + +llvm::Value *CodeGenLLVM::EmitVectorSlice(llvm::Value *vec, int begin, int extent) { + int numel = llvm::dyn_cast(vec->getType())->getNumElements(); + if (extent == numel && begin == 0) return vec; + + CHECK(begin >= 0 && extent <= numel) << "Slicing out of bound!"; + + std::vector indices(extent); + for (int i = 0; i < extent; i++) { + llvm::Constant **v = &indices[i]; + if (begin + i >= 0 && begin + i < numel) { + *v = llvm::ConstantInt::get(b_->getInt32Ty(), begin + i); + } else { + *v = llvm::UndefValue::get(b_->getInt32Ty()); + } + } + return ShuffleVector(vec, vec, llvm::ConstantVector::get(std::move(indices))); +} + +llvm::Value *CodeGenLLVM::EmitVectorPad(llvm::Value *vec, int lanes) { +#if LLVM_VERSION_MAJOR <= 10 + llvm::Value *mask = llvm::UndefValue::get(llvm::VectorType::get(b_->getInt32Ty(), lanes)); +#else + llvm::Value *mask = + llvm::UndefValue::get(llvm::VectorType::get(b_->getInt32Ty(), llvm::ElementCount(lanes, false /*Scalable*/))); +#endif + int numel = llvm::dyn_cast(vec->getType())->getNumElements(); + + CHECK(numel <= lanes); + if (numel == lanes) return vec; + for (int i = 0; i < numel; i++) { + mask = + InsertElement(mask, llvm::ConstantInt::get(b_->getInt32Ty(), i), llvm::ConstantInt::get(b_->getInt32Ty(), i)); + } + + return ShuffleVector(vec, vec, mask); +} + +llvm::Value *CodeGenLLVM::EmitVectorConcat(std::vector vecs) { + int lanes = 0; + for (auto *v : vecs) { + lanes += llvm::dyn_cast(v->getType())->getNumElements(); + } + while (vecs.size() > 1) { + std::vector new_vecs; + for (size_t i = 0; i < vecs.size() - 1; i += 2) { + auto *lhs = vecs[i]; + auto *rhs = vecs[i + 1]; + const auto lhs_lanes = llvm::dyn_cast(lhs->getType())->getNumElements(); + const auto rhs_lanes = llvm::dyn_cast(rhs->getType())->getNumElements(); + if (lhs_lanes < rhs_lanes) { + lhs = EmitVectorPad(lhs, rhs_lanes); + } else if (lhs_lanes > rhs_lanes) { + rhs = EmitVectorPad(rhs, lhs_lanes); + } + + const auto shared_lanes = std::max(lhs_lanes, rhs_lanes); + std::vector mask(lhs_lanes + rhs_lanes); + std::iota(mask.begin(), std::next(mask.begin(), lhs_lanes), 0); + std::iota(std::next(mask.begin(), lhs_lanes), mask.end(), shared_lanes); + new_vecs.push_back(ShuffleVector(lhs, rhs, mask)); + } + if (vecs.size() % 2) { + new_vecs.push_back(vecs.back()); + } + + vecs = std::move(new_vecs); + } + + return EmitVectorSlice(vecs[0], 0, lanes); +} + +llvm::Value *CodeGenLLVM::EmitBinaryOp( + llvm::Value *lhs, llvm::Value *rhs, char opcode, bool is_integral, bool is_signed) { + llvm::Instruction::BinaryOps ops; + CHECK_EQ(lhs->getType(), rhs->getType()) + << "the types of operands of binary operation are mismatch" + << ", lhs[" << DumpToString(*lhs) << "] " << opcode << " rhs[" << DumpToString(*rhs) << "]" + << ", lhs_type[" << DumpToString(*lhs->getType()) << "], rhs_type[" << DumpToString(*rhs->getType()) << "]"; + switch (opcode) { + case '+': + ops = is_integral ? 
llvm::Instruction::BinaryOps::Add : llvm::Instruction::BinaryOps::FAdd; + break; + case '-': + ops = is_integral ? llvm::Instruction::BinaryOps::Sub : llvm::Instruction::BinaryOps::FSub; + break; + case '*': + ops = is_integral ? llvm::Instruction::BinaryOps::Mul : llvm::Instruction::BinaryOps::FMul; + break; + case '/': + ops = is_integral ? (is_signed ? llvm::Instruction::BinaryOps::SDiv : llvm::Instruction::BinaryOps::UDiv) + : llvm::Instruction::BinaryOps::FDiv; + break; + case '%': + ops = is_integral ? (is_signed ? llvm::Instruction::BinaryOps::SRem : llvm::Instruction::BinaryOps::URem) + : llvm::Instruction::BinaryOps::FRem; + break; + default: + return nullptr; + } + return BinOp(ops, lhs, rhs); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::IntImm *op) { + auto *type = b_->getIntNTy(op->type().bits()); + return llvm::ConstantInt::get(type, op->value, true); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::UIntImm *op) { + if (op->type().is_bool()) { + auto *type = b_->getInt1Ty(); + return llvm::ConstantInt::get(type, op->value, false); + } + auto *type = b_->getIntNTy(op->type().bits()); + return llvm::ConstantInt::get(type, op->value, false); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::FloatImm *op) { + if (op->type().is_float(64)) { + return llvm::ConstantFP::get(b_->getDoubleTy(), op->value); + } else if (op->type().is_float(32)) { + return llvm::ConstantFP::get(b_->getFloatTy(), op->value); + } else if (op->type().is_bfloat16()) { + return llvm::ConstantFP::get(b_->getBFloatTy(), op->value); + } else if (op->type().is_float16()) { + return llvm::ConstantFP::get(b_->getHalfTy(), op->value); + } else { + LOG(FATAL) << "illegal float type."; + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::LLVMGenGlobalStringVar(const std::string &data) { return b_->CreateGlobalStringPtr(data); } + +llvm::Value *CodeGenLLVM::Visit(const ir::StringImm *op) { return LLVMGenGlobalStringVar(op->value); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Add *op) { + return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '+', is_integral_type(op->type())); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Sub *op) { + return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '-', is_integral_type(op->type())); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Mul *op) { + auto *lhs = Visit(&op->a()); + auto *rhs = Visit(&op->b()); + return EmitBinaryOp(lhs, rhs, '*', is_integral_type(op->type())); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Div *op) { + return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '/', is_integral_type(op->type())); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Mod *op) { + return EmitBinaryOp(Visit(&op->a()), Visit(&op->b()), '%', is_integral_type(op->type())); +} + +#define __IR_EMITTER_DEFINE_CMP_VISITOR(__sop, __uop, __fop) \ + auto *lhs = Visit(&op->a()); \ + auto *rhs = Visit(&op->b()); \ + CHECK(op->a().type() == op->b().type()); \ + llvm::CmpInst::Predicate predicate; \ + if (op->a().type().is_int()) { \ + predicate = llvm::CmpInst::ICMP_##__sop; \ + } else if (op->a().type().is_uint()) { \ + predicate = llvm::CmpInst::ICMP_##__uop; \ + } else /*float*/ { \ + predicate = llvm::CmpInst::FCMP_##__fop; \ + } \ + return EmitComparison(predicate, lhs, rhs, b_) + +llvm::Value *CodeGenLLVM::Visit(const ir::EQ *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(EQ, EQ, OEQ); } + +llvm::Value *CodeGenLLVM::Visit(const ir::NE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(NE, NE, ONE); } + +llvm::Value *CodeGenLLVM::Visit(const ir::LT *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SLT, ULT, OLT); } + 
+llvm::Value *CodeGenLLVM::Visit(const ir::LE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SLE, ULE, OLE); } + +llvm::Value *CodeGenLLVM::Visit(const ir::GT *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SGT, UGT, OGT); } + +llvm::Value *CodeGenLLVM::Visit(const ir::GE *op) { __IR_EMITTER_DEFINE_CMP_VISITOR(SGE, UGE, OGE); } + +#undef __IR_EMITTER_DEFINE_CMP_VISITOR + +llvm::Value *CodeGenLLVM::Visit(const ir::And *op) { return And(Visit(&op->a()), Visit(&op->b())); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Or *op) { return Or(Visit(&op->a()), Visit(&op->b())); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Min *op) { + auto *lhs = Visit(&op->a()); + auto *rhs = Visit(&op->b()); + + llvm::Value *p{nullptr}; + if (op->type().is_int()) { + p = ICmpSLT(lhs, rhs); + } else if (op->type().is_uint()) { + p = ICmpULT(lhs, rhs); + } else /*float*/ { + p = FCmpOLT(lhs, rhs); + } + + return Select(p, lhs, rhs); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Max *op) { + auto *lhs = Visit(&op->a()); + auto *rhs = Visit(&op->b()); + + llvm::Value *p = nullptr; + if (op->type().is_int()) { + p = ICmpSGT(lhs, rhs); + } else if (op->type().is_uint()) { + p = ICmpUGT(lhs, rhs); + } else /*float*/ { + p = FCmpOGT(lhs, rhs); + } + + return Select(p, lhs, rhs); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Minus *op) { + auto *v = Visit(&op->v()); + return (op->type().is_int() || op->type().is_uint()) ? Neg(v) : FNeg(v); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { return Not(Visit(&op->v())); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) { + auto from = op->v().type(); + auto to = op->type(); + + llvm::Type *source = CinnTypeToLLVMType(from, m_); + llvm::Type *target = CinnTypeToLLVMType(to, m_); + CHECK(source) << "source ir type is null"; + CHECK(target) << "target ir type is null"; + + llvm::Value *value = Visit(&op->v()); + CHECK(value) << "value is null"; + + // pod_value_t cast to a value. 
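+  // Rather than an LLVM cast instruction, this path lowers to a call into one
+  // of the runtime's cinn_pod_value_to_X helpers (selected below), since
+  // cinn_pod_value_t is a type-tagged value that only the runtime knows how
+  // to unpack.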
+ if (op->v().type().is_customized_type() && + op->v().type().customized_type() == common::customized_type::kpod_value_t) { // pod_value_t operator + llvm::Function *callee{}; + if (op->type().is_bool()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_bool); + } else if (op->type().is_int(8)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int8); + } else if (op->type().is_int(16)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int16); + } else if (op->type().is_int(32)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int32); + } else if (op->type().is_int(64)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int64); + } else if (op->type().is_uint(8)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint8); + } else if (op->type().is_uint(16)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint16); + } else if (op->type().is_uint(32)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint32); + } else if (op->type().is_uint(64)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint64); + } else if (op->type().is_float(32)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_float); + } else if (op->type().is_float(64)) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_double); + } else if (op->type().is_bfloat16()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_bfloat16); + } else if (op->type().is_float16()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_float16); + } else if (op->type() == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_void_p); + } else if (op->type() == type_of() || op->type() == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_buffer_p); + } else { + LOG(ERROR) << "can't cast cinn_pod_value_t to " << op->type(); + CINN_NOT_IMPLEMENTED + } + + CHECK(callee); + CHECK(op->v().as_var()) << "argument to the intrinsic function " + "cinn_pod_value_to_x should be a Var"; + value = GetVar(op->v().as_var()->name); + return Call(callee, std::vector({value}), "pod_value_cast"); + } + + do { + if (value->getType() == target) break; + + if (to.is_cpp_handle() || to.is_cpp_handle2()) { + value = BitCast(value, target, "cast_to_cpp_handle"); + break; + } + + if (to.is_bool()) { + if (from.is_float()) { + llvm::Constant *zero = llvm::ConstantFP::get(source, 0.); + value = FCmpONE(value, zero); + } else { + llvm::Constant *zero = llvm::ConstantInt::get(source, 0); + value = ICmpNE(value, zero); + } + break; + } + + if (from.is_float() == false && to.is_float() == false) { + value = IntCast(value, target, from.is_int()); + break; + } + + if (from.is_float() && to.is_int()) { + value = FPToSI(value, target); + break; + } + + if (from.is_float() && to.is_uint()) { + value = FPToUI(value, target); + if (to.bits() < 8) { + value = IntCast(value, target, false); + } + break; + } + + if (from.is_int() && to.is_float()) { + value = SIToFP(value, target); + break; + } + + if (from.is_uint() && to.is_float()) { + value = UIToFP(value, target); + break; + } + + CHECK(from.is_float() && to.is_float()); + value = FPCast(value, target); + } while (false); + + return value; +} + +llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) { + SymbolTableGuard symbol_table_guard(*symbol_table_); + + do { + break; + llvm::BasicBlock *preheader_bb = b_->GetInsertBlock(); + auto *for_begin = llvm::BasicBlock::Create(b_->getContext(), "for_begin", 
b_->GetInsertBlock()->getParent()); + auto *for_body = llvm::BasicBlock::Create(b_->getContext(), "for_body", b_->GetInsertBlock()->getParent()); + auto *for_end = llvm::BasicBlock::Create(b_->getContext(), "for_end", b_->GetInsertBlock()->getParent()); + + Br(for_begin); + b_->SetInsertPoint(for_begin); + + auto *begin = Visit(&op->min); + auto *loop_value = PHI(begin->getType(), 2); + loop_value->addIncoming(begin, preheader_bb); + + llvm::Value *old_var = GetVar(op->loop_var->name); + SetVar(op->loop_var->name, loop_value); + auto *end = Visit(&op->extent); + CondBr(ICmpSLT(loop_value, end), for_body, for_end); + b_->SetInsertPoint(for_body); + Visit(&op->body); + + if (old_var) { + SetVar(op->loop_var->name, old_var); + } else { + symbol_table_->Erase(op->loop_var->name); + } + + auto loop_next = Add(loop_value, llvm::ConstantInt::get(b_->getInt32Ty(), stride), "indvar.inc", true, true); + loop_value->addIncoming(loop_next, b_->GetInsertBlock()); + + Br(for_begin); + b_->SetInsertPoint(for_end); + + return nullptr; + // llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr, op->loop_var->name); + // loop_var->setAlignment(llvm::Align(4)); + // SetVar(op->loop_var->name, loop_var); + } while (false); + + //////////////////////////////////// + llvm::BasicBlock *preheader_bb = b_->GetInsertBlock(); + llvm::BasicBlock *exit_bb = nullptr; + + llvm::BasicBlock::iterator insert_point = b_->GetInsertPoint(); + + if (insert_point == preheader_bb->end()) { + CHECK(!preheader_bb->getTerminator()); + exit_bb = llvm::BasicBlock::Create(b_->getContext(), "loop_exit", b_->GetInsertBlock()->getParent(), nullptr); + } else { + CHECK(preheader_bb->getTerminator()); + exit_bb = preheader_bb->splitBasicBlock(insert_point, "loop_exit"); + preheader_bb->getTerminator()->eraseFromParent(); + } + + llvm::BasicBlock *header_bb = + llvm::BasicBlock::Create(b_->getContext(), "loop_header", b_->GetInsertBlock()->getParent(), nullptr); + llvm::BasicBlock *body_bb = + llvm::BasicBlock::Create(b_->getContext(), "loop_body", b_->GetInsertBlock()->getParent(), nullptr); + + llvm::Function *func = preheader_bb->getParent(); + b_->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); + + llvm::Value *old_var = GetVar(op->loop_var->name); + // loop iterator + llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr, op->loop_var->name); + loop_var->setAlignment(llvm::Align(4)); + SetVar(op->loop_var->name, loop_var); + + b_->SetInsertPoint(preheader_bb); + llvm::Value *start_index = Visit(&op->min); + llvm::Value *end_index = Visit(&op->extent); + Store(start_index, loop_var); + CHECK(!preheader_bb->getTerminator()); + Br(header_bb); + + // loop_header + b_->SetInsertPoint(header_bb); + llvm::Value *indvar = Load(loop_var, "indvar"); + llvm::Value *exit_cond = ICmpSGE(indvar, end_index); + CondBr(/*Cond=*/exit_cond, + /*True=*/exit_bb, + /*False=*/body_bb); + + // loop_body + b_->SetInsertPoint(body_bb); + llvm::Value *step = llvm::ConstantInt::get(b_->getInt32Ty(), stride); + + Visit(&op->body); + llvm::Value *indvar_inc = Add(indvar, + step, + "indvar.inc", + /*HasNUW=*/true, + /*HasNSW=*/true); + Store(indvar_inc, loop_var); + llvm::BranchInst *back_branch = Br(header_bb); + + // Add loop metadata + decltype(auto) ctx = b_->getContext(); + std::vector loop_metadata; + auto temp_node = llvm::MDNode::getTemporary(ctx, llvm::None); + loop_metadata.push_back(temp_node.get()); + + // TODO(fc500110): Loop vectorize + // auto *vectorization = op->metadata.vectorization ? 
b_->getTrue() : b_->getFalse(); + // loop_metadata.push_back(llvm::MDNode::get( + // ctx, {llvm::MDString::get(ctx, "llvm.loop.vectorize.enable"), + // llvm::ConstantAsMetadata::get(b_->getFalse())})); + + // Loop unroll + std::string llvm_unroll_metadata{"llvm.loop.unroll."}; + switch (op->metadata.unroll_mode) { + case ir::LLVMForLoopMeta::FullyUnroll: + llvm_unroll_metadata += "full"; + break; + case ir::LLVMForLoopMeta::NoUnroll: + llvm_unroll_metadata += "disable"; + break; + default: + llvm_unroll_metadata += "enable"; + } + + /* + loop_metadata.push_back(llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, llvm_unroll_metadata)})); + auto loop_id = llvm::MDNode::get(ctx, loop_metadata); + loop_id->replaceOperandWith(0, loop_id); + back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); + */ + + if (old_var) { + SetVar(op->loop_var->name, old_var); + } else { + symbol_table_->Erase(op->loop_var->name); + } + + b_->SetInsertPoint(exit_bb); + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::For *op) { return CreateSerialFor(op); } + +llvm::Value *CodeGenLLVM::Visit(const ir::PolyFor *op) { + CINN_NOT_IMPLEMENTED + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Select *op) { + return Select(Visit(&op->condition), Visit(&op->true_value), Visit(&op->false_value)); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::IfThenElse *op) { + SymbolTableGuard symbol_table_guard(*symbol_table_); + + bool emit_else = op->false_case.defined(); + + auto &ll_ctx = b_->getContext(); + auto *ll_function = b_->GetInsertBlock()->getParent(); + + llvm::Value *cond = Visit(&op->condition); + llvm::BasicBlock *then_block = llvm::BasicBlock::Create(ll_ctx, "if-then", ll_function); + llvm::BasicBlock *end_block = llvm::BasicBlock::Create(ll_ctx, "if-end", ll_function); + + if (op->false_case.defined()) { + llvm::BasicBlock *else_block = llvm::BasicBlock::Create(ll_ctx, "if-else", ll_function); + CondBr(cond, then_block, else_block); + + // true case + b_->SetInsertPoint(then_block); + Visit(&op->true_case); + Br(end_block); + + // false case + b_->SetInsertPoint(else_block); + Visit(&op->false_case); + Br(end_block); + } else { + CondBr(cond, then_block, end_block); + b_->SetInsertPoint(then_block); + Visit(&op->true_case); + Br(end_block); + } + b_->SetInsertPoint(end_block); + + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Block *op) { + // Create a new scope holding the temporary variables. 
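+  // The guard pushes a scope now and pops it when this Visit returns, so
+  // names bound inside the block cannot leak into the enclosing scope.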
+ SymbolTableGuard symbol_table_guard(*symbol_table_); + + llvm::Value *ret = nullptr; + + llvm::BasicBlock *block = + llvm::BasicBlock::Create(b_->getContext(), "block", b_->GetInsertBlock()->getParent(), nullptr); + + Br(block); + b_->SetInsertPoint(block); + + for (const auto &expr : op->stmts) { + ret = Visit(&expr); + } + + return ret; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::PrimitiveNode *) { CINN_NOT_IMPLEMENTED return nullptr; } +llvm::Value *CodeGenLLVM::Visit(const ir::_BufferRange_ *) { CINN_NOT_IMPLEMENTED return nullptr; } +llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlock *) { CINN_NOT_IMPLEMENTED return nullptr; } +llvm::Value *CodeGenLLVM::Visit(const ir::ScheduleBlockRealize *) { CINN_NOT_IMPLEMENTED return nullptr; } + +llvm::Value *CodeGenLLVM::Visit(const ir::Call *op) { + if (op->name == runtime::intrinsic::debug_log_repr) { + return EmitCall_debug_info(op); + } else if (op->is_extern_call()) { + auto emitter_id = ExternFuncID{backend_llvm_host, op->name.c_str()}; + const auto &fn_name = ExternFunctionEmitterRegistry::Global().Lookup(emitter_id); + if (!fn_name.empty()) { + ExternFunctionLLVMEmitter emitter(fn_name); + emitter.BindCodeGen(this); + emitter.Emit(op); + return extern_func_emit_res_; + } + } + + llvm::Function *callee = m_->getFunction(op->name); + CHECK(callee) << "Unknown function referenced. [" << op->name << "]"; + + std::vector args; + for (const auto &e : op->read_args) { + auto *arg = Visit(&e); + CHECK(arg) << "argument " << e << " is null"; + args.push_back(arg); + } + for (const auto &e : op->write_args) { + auto *arg = Visit(&e); + CHECK(arg) << "argument " << e << " is null"; + args.push_back(arg); + } + + if (op->is_cinn_call()) { + auto arg = ir::intrinsics::GetAddr::Make(op->read_args[0]); + args[0] = Visit(&arg); + args[0] = BitCast(args[0], ll_void_p_ty(), "cast_to_void_p"); + } + + return Call(callee, std::move(args)); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) { + { + Expr body_to_verify(&Reference(op)); + ir::IrVerify(body_to_verify); + } + + for (auto &fn : op->functions) { + VLOG(1) << "JIT Linking function [" << fn.As()->name << "]"; + ir::Expr fn_expr(fn); + + auto fnll = Visit(&fn_expr); + + VLOG(5) << "fn llvm:\n" << DumpToString(*fnll); + } +} + +llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) { + llvm::Value *value = GetVar(op->name, false); + llvm::Value *result{}; + CHECK(value) << "ir::_Var_[" << op->name << "]: value is null"; + // TODO(fc500110) hard coding + if (LLVM_WillVarLowerAsPointer(op->name)) { + result = value; + } else if (value->getType()->isPointerTy()) { + result = Load(value, op->name + "_load"); + } else { + result = value; + } + + return result; +} + +void CodeGenLLVM::Scalarize(const Expr &e, std::function flambda) { + if (const ir::Ramp *ramp = e.As()) { + for (int i = 0; i < ramp->type().lanes(); ++i) { + Expr offset = ramp->base + (ramp->stride * i); + VLOG(3) << "offset: " << offset; + flambda(i, Visit(&offset)); + } + } else { + llvm::Value *value = Visit(&e); + for (int i = 0; i < e->type().lanes(); ++i) { + flambda(i, b_->CreateExtractElement(value, i)); + } + } +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Load *op) { + llvm::Value *array{nullptr}; + bool is_alias{false}; + if (auto *tensor_op = op->tensor.As()) { + array = GetVar(tensor_op->name); + } else if (auto *var_op = op->tensor.As()) { + array = GetVar(var_op->name); + is_alias = alias_vars_.count(const_cast(var_op)); + } else { + array = Visit(&op->tensor); + } + CHECK(array) << "fail to Visit Load 
node: " << Expr(const_cast(op)); + + ir::Expr index = op->index(); + if (index.type().lanes() <= 1) { + std::vector indices; + indices.push_back(Visit(&index)); + + // auto load_inst = Load(InBoundsGEP(array, std::move(indices))); + auto *load_inst = AlignedLoad(InBoundsGEP(array, std::move(indices)), llvm::MaybeAlign()); + /* + if (is_alias) { + llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); + load_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + } + */ + if (auto *load_tensor = op->tensor.as_tensor()) { + AddTbaaMetadata(load_inst, load_tensor->name, op->index()); + } + + { + int alignment = op->type().bits(); + alignment = 8; + CHECK_GT(alignment, 0); + load_inst->setAlignment(llvm::Align(std::min(alignment, 8))); + } + + // TODO(fc500110): tbaa AliasAnalysis + // auto md_tbaa_root = md_builder_->createTBAARoot("cinn-tbaa"); + // auto md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", md_tbaa_root); + // llvm::MDNode *meta = md_tbaa_alias_set; + // load_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + return load_inst; + } else { // vector load + Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); + llvm::Value *buffer = Visit(&op->tensor); + if (dense_strided_ramp.defined()) { + CHECK(op->type().is_vector()); + return DenseVectorLoad(op); + } + // scalarize load + Type type = op->type(); + int alignment = type.bits() / 8; + llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); + auto flambda = [&](int i, llvm::Value *index) { + auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index); + llvm::LoadInst *load_inst = b_->CreateAlignedLoad(ptr, llvm::Align(alignment), "load_vec"); + ret = b_->CreateInsertElement(ret, load_inst, ll_const_int32(i)); + if (auto *load_tensor = op->tensor.as_tensor()) { + AddTbaaMetadata(load_inst, load_tensor->name, op->index()); + } + }; + Scalarize(op->index(), flambda); + return ret; + } +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Store *op) { + llvm::Value *array{nullptr}; + bool is_alias{false}; + if (auto *tensor_op = op->tensor.As()) { + array = GetVar(tensor_op->name); + } else if (auto *var_op = op->tensor.As()) { + array = GetVar(var_op->name); + is_alias = alias_vars_.count(const_cast(var_op)); + } + CHECK(array) << "array is null"; + + ir::Expr index = op->index(); + + if (op->type().is_scalar()) { + std::vector indices; + indices.push_back(Visit(&index)); + + // auto *store_inst = Store(Visit(&op->value), InBoundsGEP(array, std::move(indices))); + auto *store_inst = AlignedStore(Visit(&op->value), InBoundsGEP(array, std::move(indices)), llvm::MaybeAlign()); + /* + if (is_alias) { + llvm::MDNode *meta = md_builder_->createTBAANode("cinn-alias", md_tbaa_root_); + store_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + } + */ + { + int alignment = op->type().bits(); + alignment = 8; + CHECK_GT(alignment, 0); + store_inst->setAlignment(llvm::Align(std::min(alignment, 8))); + } + // TODO(fc500110): tbaa AliasAnalysis + // auto md_tbaa_root = md_builder_->createTBAARoot("cinn-tbaa"); + // auto md_tbaa_alias_set = md_builder_->createTBAANode("cinn-alias", md_tbaa_root); + // llvm::MDNode *meta = md_tbaa_alias_set; + // store_inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); + AddTbaaMetadata(store_inst, op->tensor.as_tensor()->name, op->index()); + return store_inst; + } else { // vector store + Expr dense_strided_ramp = 
detail::StridedRampBase(op->index(), 1); + auto ramp_expr = op->index(); + auto *ramp = index.As(); + auto *buffer = Visit(&op->tensor); + auto *value = Visit(&op->value); + + if (dense_strided_ramp.defined()) { // stride 1 + int total_lanes = op->type().lanes(); + int step = naive_vec_alignment_ / op->type().ElementOf().bits(); + + // fit the total_lanes in native_lanes(split into multiple native steps) + for (int offset = 0; offset < total_lanes; offset += total_lanes) { + int lanes = total_lanes; + Expr base = common::AutoSimplify(ramp->base + offset); + optim::VarModSimplify(&base); + auto *ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&base)); + auto *vtype = llvm::VectorType::get(CinnTypeToLLVMType(op->type().ElementOf(), m_, true), + llvm::ElementCount(lanes, false /*Scalable*/)) + ->getPointerTo(); + int alignment = std::max(op->type().ElementOf().bits() / 8, 1); + llvm::StoreInst *inst = + b_->CreateAlignedStore(CreateVecSlice(value, offset, lanes), b_->CreatePointerCast(ptr, vtype), alignment); + AddTbaaMetadata(inst, op->tensor.as_tensor()->name, base); + return inst; + } + } + // scalarize store + Type type = op->type(); + int alignment = type.bits() / 8; + llvm::Value *ret = llvm::UndefValue::get(CinnTypeToLLVMType(type, m_, true)); + auto flambda = [&](int i, llvm::Value *index) { + auto *ptr = CreateBufferPtr(type.ElementOf(), buffer, index); + llvm::StoreInst *store_inst = + b_->CreateAlignedStore(b_->CreateExtractElement(value, i), ptr, llvm::Align(alignment), "store_vec"); + ret = b_->CreateInsertElement(ret, store_inst, ll_const_int32(i)); + if (auto *store_tensor = op->tensor.as_tensor()) { + AddTbaaMetadata(store_inst, store_tensor->name, op->index()); + } + }; + Scalarize(op->index(), flambda); + return ret; + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Alloc *op) { + auto *buffer_op = op->destination.As(); + auto *buffer = GetVar(buffer_op->name); + CHECK(buffer); + + return buffer; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Free *op) { + auto *buffer_op = op->destination.As(); + CHECK(symbol_table_->Lookup(buffer_op->name)); + symbol_table_->Erase(buffer_op->name); + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::_Buffer_ *op) { return GetVar(op->name); } + +llvm::Value *CodeGenLLVM::Visit(const ir::_Tensor_ *op) { + return GetVar(op->name); + auto *buffer_op = op->buffer.As(); + if (symbol_table_->Lookup(buffer_op->name)) { + return Visit(buffer_op); + } + + return SetVar(buffer_op->name, Visit(buffer_op)); +} + +template ::value, int> = 0> +void appendBody(std::vector &new_body, T &&v) { + new_body.push_back(v); +} + +template ::value, int> = 1> +void appendBody(std::vector &new_body, T &&v) { + new_body.insert(new_body.end(), v.begin(), v.end()); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::_LoweredFunc_ *op) { + auto init_function_state = [this]() { alias_vars_.clear(); }; + init_function_state(); + + CHECK_EQ(op->alloc_output_buffer_exprs.size(), op->dealloc_output_buffer_exprs.size()) + << "the count of allocation and deallocation expressions is not match"; + + std::vector new_body; + auto create_temp_buffers = op->PrepareCreateTempBufferExprs(); + auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs(); + auto dealloca_temp_buffers = op->PrepareDeallocTempBufferExprs(); + + appendBody(new_body, op->argument_prepare_exprs); + appendBody(new_body, create_temp_buffers); + appendBody(new_body, alloca_temp_buffers); + appendBody(new_body, op->alloc_output_buffer_exprs); + appendBody(new_body, 
op->buffer_data_cast_exprs); + appendBody(new_body, op->body); + appendBody(new_body, dealloca_temp_buffers); + appendBody(new_body, op->dealloc_output_buffer_exprs); + + ir::Expr function_body = ir::Block::Make(new_body); + + // Emit Function + std::vector arg_types = {b_->getInt8PtrTy(), b_->getInt32Ty()}; + + llvm::FunctionType *function_type = llvm::FunctionType::get( + /*Result=*/b_->getVoidTy(), + /*Params=*/std::move(arg_types), + /*isVarArg=*/false); + CHECK(m_->getFunction(op->name) == nullptr) << "function[" << op->name << "] exists"; + + f_ = llvm::Function::Create( + /*FunctionType=*/function_type, + /*LinkageTypes=*/llvm::Function::ExternalLinkage, + /*Name=*/op->name, + /*Module=*/m_); + f_->setCallingConv(llvm::CallingConv::C); + f_->setHasUWTable(); // GDB + + std::vector args; + args.reserve(f_->arg_size()); + std::transform( + f_->arg_begin(), f_->arg_end(), std::back_inserter(args), [](auto &arg) { return std::addressof(arg); }); + + llvm::BasicBlock *entry = llvm::BasicBlock::Create( + /*Context=*/b_->getContext(), + /*Name=*/"entry", + /*Parent=*/f_, + /*InsertBefore=*/nullptr); + + SetVar("_args", args[0]); + b_->SetInsertPoint(entry); + Visit(&function_body); + symbol_table_->Erase("_args"); + RetVoid(); + return f_; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Let *op) { + CHECK(op->type().valid()); + auto name = op->symbol.As()->name; + if (op->symbol.As()->type().is_cpp_handle()) { + alias_vars_.insert(const_cast(op->symbol.As())); + } + if (op->body.defined()) { + SetVar(name, Visit(&op->body)); + } else { + llvm::AllocaInst *inst = Alloca(CinnTypeToLLVMType(op->type(), m_), nullptr, name); + auto get_align = [](int n) { + int i{0}, r{1}; + while (n > r) { + r *= 2; + ++i; + } + return r / 8; + }; + int align_bits = std::max(op->type().bits(), 8); + int align = get_align(align_bits); + inst->setAlignment(llvm::Align(align)); + SetVar(name, inst); + } + + return GetVar(name); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Reduce *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Ramp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Broadcast *op) { +#if LLVM_VERSION_MAJOR >= 11 + const llvm::ElementCount elem_count(op->lanes, /*scalable*/ false); +#else + const int elem_count = op->lanes; +#endif + llvm::Value *value = Visit(&op->value); + llvm::Constant *undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), elem_count)); + llvm::Constant *zero = llvm::ConstantInt::get(ll_int32_ty(), 0); + value = b_->CreateInsertElement(undef, value, zero, "broadcast"); + llvm::Constant *zeros = llvm::ConstantVector::getSplat(elem_count, zero); + return b_->CreateShuffleVector(value, undef, zeros, "broadcast_shuffle"); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); } + +llvm::Value *CodeGenLLVM::Visit(const ir::Product *op) { + auto size = op->operands().size(); + if (size == 0) return nullptr; + + llvm::Value *ret = Visit(&op->operand(0)); + for (int i = 1; i < size; i++) { + llvm::Value *v = Visit(&op->operand(i)); + if (is_integral_type(op->type())) { + ret = Mul(ret, v); + } else { + ret = FMul(ret, v); + } + } + + return ret; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::Sum *op) { + auto size = op->operands().size(); + if (size == 0) return nullptr; + + llvm::Value *ret = Visit(&op->operand(0)); + for (int i = 1; i < size; i++) { + llvm::Value *v = Visit(&op->operand(i)); + if (is_integral_type(op->type())) { + ret = Add(ret, v); 
+ } else { // float + ret = FAdd(ret, v); + } + } + + return ret; +} + +#undef __IR_EMITTER_CINN_NOT_IMPLEMENTED + +void CodeGenLLVM::Compile(const ir::Module &module) { Visit(module.self()); } + +llvm::Value *CodeGenLLVM::EmitCall_buffer_malloc(const ir::Call *op) { return nullptr; } + +llvm::Value *CodeGenLLVM::EmitCall_get_address(const ir::Call *op) { + if (auto *read_var = op->read_args.front().as_var()) { + return GetVar(read_var->name); + } + + if (auto *read_buf = op->read_args.front().as_buffer()) { + return GetVar(read_buf->name); + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::EmitCall_debug_info(const ir::Call *op) { + auto callee = m_->getFunction(runtime::intrinsic::debug_log_repr); + CHECK_GE(op->read_args.size(), 1UL); + std::vector args; + for (auto &arg : op->read_args) { + args.push_back(Visit(&arg)); + } + return Call(callee, args, "call debug_info"); +} + +llvm::Value *CodeGenLLVM::GetVar(const std::string &name, bool lazy) { + auto symbol = symbol_table_->Lookup(name); + if (!lazy) { + CHECK(symbol) << "No var [" << name << "] found"; + } + return symbol; +} + +llvm::Value *CodeGenLLVM::SetVar(const std::string &name, llvm::Value *val) { + symbol_table_->Insert(name, val); + CHECK(GetVar(name)); + return val; +} + +llvm::FunctionType *CodeGenLLVM::GenFunctionTypeFromCinnFunction(const ir::_LoweredFunc_ *func, bool with_buffer_type) { + auto func_ret_type = CinnTypeToLLVMType(Void(), m_); + std::vector arg_types; + for (auto &arg : func->args) { + if (arg.is_buffer() && arg.is_var()) { + alias_vars_.insert(arg.var_arg().get()); + } + if (arg.is_var()) { + arg_types.push_back(CinnTypeToLLVMType(arg.var_arg()->type(), m_)); + } else if (arg.is_buffer()) { + if (with_buffer_type) { + arg_types.push_back(ll_cinn_buffer_p_ty()); + } else { + arg_types.push_back(CinnTypeToLLVMType(arg.buffer_arg()->type(), m_)); + } + } + } + + return llvm::FunctionType::get(func_ret_type, arg_types, false); +} + +llvm::Value *CodeGenLLVM::DenseVectorLoad(const ir::Load *op) { + auto index = op->index(); + auto *ramp = index.As(); + CHECK(ramp); + + int load_lanes = op->type().lanes(); + int native_lanes = naive_vec_alignment_ / op->type().bits(); + + std::vector slices; + + llvm::Value *buffer = Visit(&op->tensor); + buffer->setName("buffer"); + + for (int i = 0; i < load_lanes; i += load_lanes) { + int slice_lanes = load_lanes; + auto slice_base = common::AutoSimplify(ramp->base + i); + optim::VarModSimplify(&slice_base); + auto slide_stride = Expr(1); + auto slide_index = slice_base; + +#if LLVM_VERSION_MAJOR >= 11 + const llvm::ElementCount elem_count(slice_lanes, /*scalable*/ false); +#else + const int elem_count = slice_lanes; +#endif + + llvm::Type *slice_type = llvm::VectorType::get(CinnTypeToLLVMType(op->type().ElementOf(), m_, true), elem_count); + + llvm::Value *elt_ptr = CreateBufferPtr(op->type().ElementOf(), buffer, Visit(&slice_base)); + llvm::Value *vec_ptr = b_->CreatePointerCast(elt_ptr, slice_type->getPointerTo(), "get_vec_ptr"); + + int alignment = std::max(op->type().ElementOf().bits() / 8, 1); + + llvm::Instruction *load_inst = b_->CreateAlignedLoad(vec_ptr, llvm::Align(alignment), "load_vec"); + AddTbaaMetadata(load_inst, op->tensor.as_tensor()->name, op->index()); + + slices.push_back(load_inst); + } + + CHECK_EQ(slices.size(), 1UL); + + return slices[0]; +} + +llvm::Value *CodeGenLLVM::CreateBufferVecPtr(Type t, llvm::Value *buffer, llvm::Value *index) { + CHECK_GT(t.lanes(), 1) << "type is not a vector type: " << t; + llvm::PointerType *btype = 
llvm::dyn_cast(buffer->getType()); + CHECK(btype); + llvm::PointerType *ptype = CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); + if (btype != ptype) { + buffer = b_->CreatePointerCast(buffer, ptype); + } + return b_->CreateInBoundsGEP(buffer, index); +} + +llvm::Value *CodeGenLLVM::CreateBufferPtr(Type t, llvm::Value *buffer, llvm::Value *index) { + CHECK_EQ(t.lanes(), 1); + auto *btype = llvm::dyn_cast(buffer->getType()); + CHECK(btype); + auto *ptype = CinnTypeToLLVMType(t, m_)->getPointerTo(btype->getAddressSpace()); + CHECK(ptype); + if (btype != ptype) { + buffer = b_->CreatePointerCast(buffer, ptype, "pointer_cast"); + } + return b_->CreateInBoundsGEP(buffer, index, "buffer_ptr"); +} + +llvm::Value *CodeGenLLVM::CreateVecSlice(llvm::Value *vec, int begin, int lanes) { + int total_lanes = llvm::dyn_cast(vec->getType())->getNumElements(); + CHECK_LE(begin + lanes, total_lanes); + if (lanes == total_lanes && begin == 0) return vec; // full slice + std::vector indices; + for (int i = 0; i < lanes; ++i) { + indices.push_back(ll_const_int32(begin + i)); + } + llvm::Constant *undef = llvm::UndefValue::get(vec->getType()); + return b_->CreateShuffleVector(vec, undef, llvm::ConstantVector::get(indices)); +} + +void CodeGenLLVM::InitTarget(const Target &target) { + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + switch (target.arch) { + case Target::Arch::X86: + if (target.bits == Target::Bit::k32) { + naive_vec_alignment_ = 256; + } else if (target.bits == Target::Bit::k64) { + naive_vec_alignment_ = 512; + } else { + LOG(FATAL) << "get unknown bits"; + } + break; + case Target::Arch::ARM: + naive_vec_alignment_ = 128; + break; + case Target::Arch::NVGPU: + naive_vec_alignment_ = 128; + break; + case Target::Arch::Unk: + LOG(FATAL) << "unknown Arch found"; + break; + } +} + +bool LLVM_WillVarLowerAsPointer(const std::string &var_name) { + return var_name == "_args" || utils::Endswith(var_name, "__ptr"); +} + +void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buffer, Expr index) { + // If the index is constant, generate some TBAA info that helps LLVM understand our loads/stores aren't aliased. + bool constant_index = false; + int base = 0; + int width = 1; + + if (index.defined()) { + if (const ir::Ramp *ramp = index.As()) { + auto *pstride_int = ramp->stride.As(); + auto *pbase_int = ramp->base.As(); + if (pstride_int && pbase_int) { + int stride = pstride_int->value; + base = pbase_int->value; + CHECK_GE(base, 0); + width = NextPowerOfTwo(ramp->lanes * stride); + + while (base % width) { + base -= base % width; + width *= 2; + } + constant_index = true; + } + } else { + auto *pbase_int = index.As(); + if (pbase_int) { + int pbase = pbase_int->value; + base = pbase; + constant_index = true; + } + } + } + + llvm::MDBuilder builder(b_->getContext()); + + // Add type-based-alias-analysis metadata to the pointer, so that loads and stores to different buffers can get + // reordered. + llvm::MDNode *tbaa = builder.createTBAARoot("cinn buffer"); + tbaa = builder.createTBAAScalarTypeNode(std::string(buffer), tbaa); + + // Add metadata for constant indices to allow loads and stores to the same buffer to get reordered. 
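+  // Each node generated below names an aligned window (width w, base b) that
+  // contains the access; two constant-index accesses are considered to alias
+  // only if their windows nest, so LLVM may freely reorder the rest.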
+ if (constant_index) { + for (int w = 1024; w >= width; w /= 2) { + int b = (base / w) * w; + tbaa = builder.createTBAAScalarTypeNode(utils::StringFormat("%s.width%d.base%d", buffer.data(), w, b), tbaa); + } + } + + tbaa = builder.createTBAAStructTagNode(tbaa, tbaa, 0); + inst->setMetadata("tbaa", tbaa); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::IntrinsicOp *op) { + switch (op->getKind()) { +#define __(op__) \ + case ir::IntrinsicKind::k##op__: \ + return Visit(llvm::dyn_cast(op)); + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferGetDataHandle *op) { + std::vector args({Visit(&op->buffer)}); + auto *callee = m_->getFunction("cinn_buffer_get_data_handle"); + return Call(callee, std::move(args)); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferGetDataConstHandle *op) { + std::vector args({Visit(&op->buffer)}); + auto *callee = m_->getFunction("cinn_buffer_get_data_const_handle"); + return Call(callee, std::move(args)); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BufferCreate *op) { + auto *callee = m_->getFunction(runtime::intrinsic::buffer_create_default); + auto buffer_node = op->buffer.as_buffer(); + CHECK(buffer_node); + std::vector args({ll_const_int32(buffer_node->target.runtime_arch())}); + uint64_t memory_size = (buffer_node->dtype.ElementOf().bits() + 7) / 8; + for (auto shape : buffer_node->shape) { + int shape_int = shape.as_int32(); + memory_size *= shape_int; + } + args.push_back(ll_const_int64(memory_size)); + args.push_back(ll_const_int32(32)); + + return Call(callee, args); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::GetAddr *op) { + if (auto *n = op->data.as_var()) { + return GetVar(n->name); + } else if (auto *n = op->data.as_buffer()) { + return GetVar(n->name); + } + if (auto *n = op->data.As()) { // get the address to an element in a buffer + auto *e = Visit(&op->data); + if (auto *e_load = llvm::dyn_cast(e)) { + return e_load->getPointerOperand(); + } + return e; + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::ArgsConstruct *op) { + llvm::SmallVector args; + Expr var(op->var); + var->set_type(type_of()); + var = ir::intrinsics::GetAddr::Make(var); + + llvm::Value *ll_var = Visit(&var); + var = ir::Cast::Make(type_of(), var); + + Expr num_args(static_cast(op->args.size())); + args.push_back(BitCast(ll_var, ll_cinn_pod_p_ty(), "cast_to_pod_value_t_ptr")); + args.push_back(Visit(&num_args)); + for (auto &arg : op->args) { + args.push_back(Visit(&arg)); + } + + auto *callee = m_->getFunction(runtime::intrinsic::args_construct_repr); + return Call(callee, std::move(args)); +} + +llvm::Function *CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type *ret_type, + llvm::ArrayRef arg_types) { + llvm::Module *module = m_; + + if (!llvm::Intrinsic::isOverloaded(id)) { + return llvm::Intrinsic::getDeclaration(module, id, {}); + } + + llvm::SmallVector infos; + llvm::Intrinsic::getIntrinsicInfoTableEntries(id, infos); + llvm::SmallVector overload_types; + + auto try_match = [&](llvm::FunctionType *f_ty, bool var_arg) { + overload_types.clear(); + llvm::ArrayRef ref(infos); + auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + if (llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref)) { + return llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg; + } + } + return match; + }; + + auto *fn_ty = llvm::FunctionType::get(ret_type, 
arg_types, false); + switch (try_match(fn_ty, false)) { + case llvm::Intrinsic::MatchIntrinsicTypes_Match: + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchRet: + return nullptr; + case llvm::Intrinsic::MatchIntrinsicTypes_NoMatchArg: + break; + } + + // try matching the var arg signature. + llvm::SmallVector var_types; + for (int i = 0; i <= arg_types.size(); ++i) { + if (i > 0) { + var_types.push_back(arg_types[i - 1]); + } + auto *ft = llvm::FunctionType::get(ret_type, var_types, true); + if (try_match(ft, true) == llvm::Intrinsic::MatchIntrinsicTypes_Match) { + return llvm::Intrinsic::getDeclaration(module, id, overload_types); + } + } + return nullptr; +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::BuiltinIntrin *op) { + std::string func_name = op->name; + if (op->id == -1) { + if (func_name == "bitwise_and") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateAnd(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_or") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateOr(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_xor") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateXor(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "bitwise_not") { + CHECK_GE(op->args.size(), 1U); + return b_->CreateNot(Visit(&op->args[0])); + } else if (func_name == "left_shift") { + CHECK_GE(op->args.size(), 2U); + return b_->CreateShl(Visit(&op->args[0]), Visit(&op->args[1])); + } else if (func_name == "right_shift") { + CHECK_GE(op->args.size(), 2U); + if (op->args[0]->type().is_int()) { + return b_->CreateAShr(Visit(&op->args[0]), Visit(&op->args[1])); + } else { + return b_->CreateLShr(Visit(&op->args[0]), Visit(&op->args[1])); + } + } else if (func_name == "isnan") { + CHECK_GE(op->args.size(), 1U); + llvm::Value *v = Visit(&op->args[0]); + return b_->CreateFCmpUNO(v, v); + } + } + + llvm::Intrinsic::ID id = op->id; + int64_t num_signature = op->arg_nums; + std::vector arg_value; + std::vector arg_type; + for (size_t i = 0; i < op->args.size(); ++i) { + arg_value.push_back(Visit(&op->args[i])); + if (i < static_cast(num_signature)) { + arg_type.push_back(arg_value.back()->getType()); + } + } + CHECK(!op->args.empty()); + llvm::Type *return_type = CinnTypeToLLVMType(op->type(), m_, true); + llvm::Function *fn = GetIntrinsicDecl(id, return_type, arg_type); + CHECK(fn) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); + return b_->CreateCall(fn, arg_value); +} + +llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) { + auto to_type = op->GetOutputType(0); + llvm::Function *callee{}; + + if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_float); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_double); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_bfloat16); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_float16); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_bool); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int8); + } else if (to_type == type_of()) { + callee = m_->getFunction(runtime::intrinsic::pod_value_to_int16); + } else if (to_type == type_of()) { + callee = 
m_->getFunction(runtime::intrinsic::pod_value_to_int32);
+  } else if (to_type == type_of<int64_t>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_int64);
+  } else if (to_type == type_of<uint8_t>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint8);
+  } else if (to_type == type_of<uint16_t>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint16);
+  } else if (to_type == type_of<uint32_t>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint32);
+  } else if (to_type == type_of<uint64_t>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_uint64);
+  } else if (to_type == type_of<void *>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_void_p);
+  } else if (to_type == type_of<cinn_buffer_t *>()) {
+    callee = m_->getFunction(runtime::intrinsic::pod_value_to_buffer_p);
+  } else {
+    LOG(FATAL) << "Not supported type: " << to_type;
+  }
+
+  CHECK(callee);
+  auto *value = Visit(&op->pod_value_ptr);
+  CHECK(value);
+  return Call(callee, std::vector<llvm::Value *>({value}), "pod_value_cast");
+}
+
+}  // namespace backends
+}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/codegen_llvm.h b/paddle/cinn/backends/llvm/codegen_llvm.h
new file mode 100644
index 0000000000000..f472e2239e15d
--- /dev/null
+++ b/paddle/cinn/backends/llvm/codegen_llvm.h
@@ -0,0 +1,248 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "cinn/backends/llvm/ir_builder_mixin.h"
+#include "cinn/backends/llvm/llvm_util.h"
+#include "cinn/ir/intrinsic_ops.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/ir/lowered_func.h"
+#include "cinn/ir/module.h"
+
+namespace cinn {
+namespace backends {
+
+class LLVMIRVisitor : public ir::IRVisitorBase<llvm::Value *> {
+ public:
+  LLVMIRVisitor() = default;
+
+  using ir::IRVisitorBase<llvm::Value *>::Visit;
+#define __m(t__) virtual llvm::Value *Visit(const ir::t__ *x) = 0;
+  NODETY_FORALL(__m)
+#undef __m
+};
+
+/**
+ * Tell whether a variable called \p var_name will be lowered to a pointer type in LLVM.
+ * @param var_name name of the variable.
+ * @return a boolean.
+ */
+bool LLVM_WillVarLowerAsPointer(const std::string &var_name);
+
+class SymbolTable {
+ public:
+  SymbolTable() = default;
+
+  void PushScope() { scopes_.emplace_back(); }
+
+  llvm::Value *Lookup(const std::string &id) {
+    for (auto it = scopes_.rbegin(); it != scopes_.rend(); it++) {
+      auto vt = (*it).find(id);
+      if (vt != (*it).end()) return vt->second;
+    }
+    return nullptr;
+  }
+
+  void Insert(const std::string &id, llvm::Value *value) {
+    CHECK(!scopes_.empty());
+    scopes_.back().emplace(id, value);
+  }
+
+  void Erase(const std::string &id) {
+    CHECK(!scopes_.empty());
+    scopes_.back().erase(id);
+  }
+
+  void PopScope() {
+    CHECK(!scopes_.empty());
+    scopes_.pop_back();
+  }
+
+  //! Get the number of the variables contained in the current scope.
+  size_t size() const { return scopes_.empty() ? 0 : scopes_.back().size(); }
+
+  size_t num_scopes() const { return scopes_.size(); }
+
+ private:
+  std::vector<absl::flat_hash_map<std::string, llvm::Value *>> scopes_;
+
+  SymbolTable(const SymbolTable &) = delete;
+};
+
+struct SymbolTableGuard {
+  explicit SymbolTableGuard(SymbolTable &symbol_table) : symbol_table_(symbol_table) { symbol_table.PushScope(); }
+
+  ~SymbolTableGuard() { symbol_table_.PopScope(); }
+
+ private:
+  SymbolTable &symbol_table_;
+};
+
+/**
+ * Base class of all the LLVM-based codegen.
+ */
+class CodeGenLLVM : public LLVMIRVisitor, public IrBuilderMixin<CodeGenLLVM> {
+ public:
+  explicit CodeGenLLVM(llvm::Module *m,
+                       llvm::IRBuilder<> *b,
+                       const std::shared_ptr<SymbolTable> &symbol_table = nullptr,
+                       const Target &target = common::DefaultHostTarget());
+
+  // Common llvm types
+  // @{
+  inline llvm::Type *ll_void_p_ty() const { return llvm_type_of<void *>(m_); }
+  inline llvm::Type *ll_void_pp_ty() const { return llvm_type_of<void **>(m_); }
+
+  inline llvm::Type *ll_int8_ty() const { return llvm_type_of<int8_t>(m_); }
+  inline llvm::Type *ll_int16_ty() const { return llvm_type_of<int16_t>(m_); }
+  inline llvm::Type *ll_int32_ty() const { return llvm_type_of<int32_t>(m_); }
+  inline llvm::Type *ll_int64_ty() const { return llvm_type_of<int64_t>(m_); }
+
+  inline llvm::Type *ll_uint8_ty() const { return llvm_type_of<uint8_t>(m_); }
+  inline llvm::Type *ll_uint16_ty() const { return llvm_type_of<uint16_t>(m_); }
+  inline llvm::Type *ll_uint32_ty() const { return llvm_type_of<uint32_t>(m_); }
+  inline llvm::Type *ll_uint64_ty() const { return llvm_type_of<uint64_t>(m_); }
+
+  inline llvm::Type *ll_bf16_ty() const { return llvm_type_of<bfloat16>(m_); }
+  inline llvm::Type *ll_fp16_ty() const { return llvm_type_of<float16>(m_); }
+  inline llvm::Type *ll_fp32_ty() const { return llvm_type_of<float>(m_); }
+  inline llvm::Type *ll_fp64_ty() const { return llvm_type_of<double>(m_); }
+
+  inline llvm::Type *ll_cinn_buffer_p_ty() const { return llvm_type_of<cinn_buffer_t *>(m_); }
+  inline llvm::Type *ll_cinn_pod_ty() const { return llvm_type_of<cinn_pod_value_t>(m_); }
+  inline llvm::Type *ll_cinn_pod_p_ty() const { return llvm_type_of<cinn_pod_value_t *>(m_); }
+  // @}
+
+  //! get a llvm type equivalent to a CINN type.
+  inline llvm::Type *ll_type_of(Type type) { return CinnTypeToLLVMType(type, m_); }
+
+  // Common methods to get a constant
+  // @{
+  inline llvm::Constant *ll_const_int32(int v) const { return llvm::ConstantInt::get(b_->getInt32Ty(), v); }
+  inline llvm::Constant *ll_const_int64(int v) const { return llvm::ConstantInt::get(b_->getInt64Ty(), v); }
+  // @}
+
+  //! Get the bound LLVM module.
+  llvm::Module *m() { return m_; }
+  //! Get the bound LLVM ir builder.
+  llvm::IRBuilder<> *b() { return b_; }
+
+  void Compile(const ir::Module &module);
+
+  using LLVMIRVisitor::Visit;
+
+#define __(op__) llvm::Value *Visit(const ir::op__ *) override;
+  NODETY_FORALL(__)
+#undef __
+
+#define __(op__) llvm::Value *Visit(const ir::intrinsics::op__ *);
+  INTRINSIC_KIND_FOR_EACH(__)
+#undef __
+
+  //! Used for the ExternFuncEmitter to store temporary result.
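+  //  (Written by ExternFunctionLLVMEmitter while handling an extern ir::Call,
+  //  then read back as the result of CodeGenLLVM::Visit(const ir::Call *).)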
+ mutable llvm::Value *extern_func_emit_res_{}; + + std::shared_ptr named_vars() { return symbol_table_; } + + llvm::FunctionType *GenFunctionTypeFromCinnFunction(const ir::_LoweredFunc_ *func, bool with_buffer_type); + + virtual llvm::Value *GetVar(const std::string &name, bool lazy = true); + + llvm::Function *GetIntrinsicDecl(llvm::Intrinsic::ID id, + llvm::Type *ret_type, + llvm::ArrayRef arg_types); + + // Constants + // @{ + inline llvm::Value *llvm_int32_constant(int v) { return llvm::ConstantInt::get(ll_int32_ty(), v); } + // @} + + virtual ~CodeGenLLVM(); + + protected: + // TODO(Superjomn) When to clear the existing local variables when switch to another function? + llvm::Value *SetVar(const std::string &name, llvm::Value *val); + llvm::Value *EmitVectorSlice(llvm::Value *vec, int begin, int extent); + llvm::Value *EmitVectorPad(llvm::Value *vec, int lanes); + llvm::Value *EmitVectorConcat(std::vector vecs); + + //! Visit different kinds of Calls, the following methods are analogous to + //! those in CodeGenC. + // @{ + llvm::Value *EmitCall_buffer_create(const ir::Call *op); + llvm::Value *EmitCall_buffer_malloc(const ir::Call *op); + llvm::Value *EmitCall_get_address(const ir::Call *op); + llvm::Value *EmitCall_debug_info(const ir::Call *op); + // @} + + llvm::Value *EmitBinaryOp(llvm::Value *lhs, llvm::Value *rhs, char opcode, bool is_integral, bool is_signed = true); + + llvm::Value *LLVMGenGlobalStringVar(const std::string &data); + + llvm::Value *CreateBufferPtr(Type t, llvm::Value *buffer, llvm::Value *index); + llvm::Value *CreateBufferVecPtr(Type t, llvm::Value *buffer, llvm::Value *index); + llvm::Value *CreateVecSlice(llvm::Value *vec, int begin, int lanes); + + llvm::Value *DenseVectorLoad(const ir::Load *load); + llvm::Value *CreateSerialFor(const ir::For *op, int stride = 1); + + /** + * Mark a load or store with type-based-alias-analysis metadata so that LLVM can optimize by reordering loads and + * stores across different buffers. + */ + void AddTbaaMetadata(llvm::Instruction *inst, absl::string_view buffer, Expr index); + + void InitTarget(const Target &target); + + void Scalarize(const Expr &e, std::function flambda); + + llvm::Module *m_; + llvm::IRBuilder<> *b_; + // Current function + llvm::Function *f_; + + std::unique_ptr md_builder_; + + // std::shared_ptr> named_vars_; + std::shared_ptr symbol_table_; + std::unordered_set alias_vars_; + + llvm::MDNode *md_tbaa_root_{nullptr}; + llvm::MDNode *md_tbaa_alias_set_{nullptr}; + + int naive_vec_alignment_{0}; + Target target_; +}; +namespace detail { +Expr StridedRampBase(Expr e, int stride); +} // namespace detail + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/llvm/codegen_llvm_test.cc b/paddle/cinn/backends/llvm/codegen_llvm_test.cc new file mode 100644 index 0000000000000..ebeaf20f01577 --- /dev/null +++ b/paddle/cinn/backends/llvm/codegen_llvm_test.cc @@ -0,0 +1,623 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/llvm/codegen_llvm.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "cinn/backends/llvm/cinn_runtime_llvm_ir.h" +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/module.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" + +namespace cinn { +namespace backends { + +namespace { + +auto CreateCodeGenLLVMTestLLVM() { + auto context = std::make_unique(); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_codegen_llvm", *context); + auto emitter = std::make_unique(m.get(), b.get()); + + return std::make_tuple(std::move(m), std::move(b), std::move(context), std::move(emitter)); +} + +auto CreateTensor() { + ir::Expr M(3); + ir::Expr N(2); + lang::Placeholder a("a", {M, N}); + lang::Placeholder b("b", {M, N}); + auto c = lang::Compute( + {M, N}, [&](auto i, auto j) { return a(i, j) + b(i, j); }, "c"); + + lang::Buffer c_buf(common::Float(32)); + + return std::make_tuple(std::move(a), std::move(b), std::move(c), std::move(c_buf)); +} + +auto CreateLLVMType(llvm::LLVMContext *context) { + llvm::Type *i8 = llvm::Type::getInt8Ty(*context); + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *i64 = llvm::Type::getInt64Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); + + return std::make_tuple(i8, i32, i64, u32, f32, f16, bf16); +} + +template +auto CreateBinaryOp(common::Type t, T1 x, T2 y) { + auto px = std::make_unique(t, x); + auto py = std::make_unique(t, y); + + auto ex = ir::Expr(px.release()); + auto ey = ir::Expr(py.release()); + + return std::make_unique(std::move(ex), std::move(ey)); +} + +auto CreateIrBuffer(common::Type t, std::string name, std::vector shape, int data_alignment = 0) { + CHECK_GE(data_alignment, 0); + auto buffer = ir::_Buffer_::Make(std::move(name), std::move(t)); + + if (data_alignment) { + buffer->data_alignment = data_alignment; + } + + for (auto i : shape) { + auto pi = std::make_unique(common::Int(32), i); + buffer->shape.emplace_back(pi.release()); + } + + return buffer; +} + +auto CreateIrTensor(std::string name, std::vector shape) { + std::vector shape_expr; + for (auto i : shape) { + auto pi = std::make_unique(common::Int(32), i); + shape_expr.emplace_back(pi.release()); + } + + ir::Tensor tensor(std::move(name), Float(32), shape_expr, shape_expr, {}, {}); + tensor->domain = tensor->shape; + return tensor; +} + +auto CreateLoweredFunc() { + // +} + +} // namespace + +using cinn::common::bfloat16; +using cinn::common::float16; + +TEST(CodeGenLLVM, Imm) { + auto context = std::make_unique(); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_codegen_llvm", *context); + auto emitter = std::make_unique(m.get(), b.get()); + + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); + + llvm::Value *value = nullptr; + + ir::IntImm i32_imm(common::Int(32), 10); + value = emitter->Visit(&i32_imm); + ASSERT_EQ(value->getType(), i32); + 
ASSERT_EQ(value, llvm::ConstantInt::get(i32, i32_imm.value, true)); + // value->print(llvm::outs(), false); + + ir::UIntImm u32_imm(common::UInt(32), 5); + value = emitter->Visit(&u32_imm); + ASSERT_EQ(value->getType(), u32); + ASSERT_EQ(value, llvm::ConstantInt::get(u32, u32_imm.value, false)); + + ir::FloatImm float32_imm(common::Float(32), 2.5); + value = emitter->Visit(&float32_imm); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, llvm::ConstantFP::get(f32, float32_imm.value)); + + ir::FloatImm float16_imm(common::Float16(), 2.5); + value = emitter->Visit(&float16_imm); + ASSERT_EQ(value->getType(), f16); + ASSERT_EQ(value, llvm::ConstantFP::get(f16, float16_imm.value)); + + ir::FloatImm bfloat16_imm(common::BFloat16(), 2.5); + value = emitter->Visit(&bfloat16_imm); + ASSERT_EQ(value->getType(), bf16); + ASSERT_EQ(value, llvm::ConstantFP::get(bf16, bfloat16_imm.value)); +} + +TEST(CodeGenLLVM, Expr) { + auto context = std::make_unique(); + auto b = std::make_unique>(*context); + auto m = std::make_unique("test_binary_op", *context); + auto emitter = std::make_unique(m.get(), b.get()); + + llvm::Type *i1 = llvm::Type::getInt1Ty(*context); + llvm::Type *i8 = llvm::Type::getInt8Ty(*context); + llvm::Type *i32 = llvm::Type::getInt32Ty(*context); + llvm::Type *i64 = llvm::Type::getInt64Ty(*context); + llvm::Type *u32 = llvm::Type::getInt32Ty(*context); + llvm::Type *f32 = llvm::Type::getFloatTy(*context); + llvm::Type *f16 = llvm::Type::getHalfTy(*context); + llvm::Type *bf16 = llvm::Type::getBFloatTy(*context); + + llvm::Value *value = nullptr; + llvm::Value *expect_value = nullptr; + + std::string outs; + llvm::raw_string_ostream ss(outs); + + // + + do { + int x = 2; + int y = 3; + auto op = CreateBinaryOp(common::Int(32), x, y); + + expect_value = llvm::ConstantInt::get(i32, x + y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), i32); + ASSERT_EQ(value, expect_value); + // value->print(llvm::outs(), false); + // value->print(ss, false); + // LOG(INFO) << "xxx: " << ss.str(); + } while (false); + + // - + do { + float x = 2.5; + float y = 3.5; + auto op = CreateBinaryOp(common::Float(32), x, y); + + expect_value = llvm::ConstantFP::get(f32, x - y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + } while (false); + + // - + do { + float16 x{2.5}; + float16 y{3.5}; + auto op = CreateBinaryOp(common::Float16(), x, y); + + expect_value = llvm::ConstantFP::get(f16, x - y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), f16); + ASSERT_EQ(value, expect_value); + } while (false); + + // - + do { + bfloat16 x{2.5}; + bfloat16 y{3.5}; + auto op = CreateBinaryOp(common::BFloat16(), x, y); + + expect_value = llvm::ConstantFP::get(bf16, x - y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), bf16); + ASSERT_EQ(value, expect_value); + } while (false); + + // * + do { + int x = 5; + int y = 3; + auto op = CreateBinaryOp(common::Int(64), x, y); + expect_value = llvm::ConstantInt::get(i64, x * y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), i64); + ASSERT_EQ(value, expect_value); + } while (false); + + // / + do { + float x = 6; + float y = 4; + auto op = CreateBinaryOp(common::Float(32), x, y); + expect_value = llvm::ConstantFP::get(f32, x / y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + } while (false); + + // / + do { + float16 x{6}; + float16 y{4}; + auto op = CreateBinaryOp(common::Float16(), 
x, y); + expect_value = llvm::ConstantFP::get(f16, x / y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), f16); + ASSERT_EQ(value, expect_value); + } while (false); + + // / + do { + bfloat16 x{6}; + bfloat16 y{4}; + auto op = CreateBinaryOp(common::BFloat16(), x, y); + expect_value = llvm::ConstantFP::get(bf16, x / y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), bf16); + ASSERT_EQ(value, expect_value); + } while (false); + + // % + do { + int x = 25; + int y = 7; + auto op = CreateBinaryOp(common::Int(32), x, y); + expect_value = llvm::ConstantInt::get(i32, x % y); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), i32); + ASSERT_EQ(value, expect_value); + } while (false); + + // == + do { + int x = 3; + int y = 3; + auto op = CreateBinaryOp(common::Int(32), x, y); + expect_value = llvm::ConstantInt::get(i1, 1); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // != + do { + float x = 3; + float y = 3; + + auto op = CreateBinaryOp(common::Float(32), x, y); + expect_value = llvm::ConstantInt::get(i1, 0); + value = emitter->Visit(op.get()); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // < + do { + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantInt::get(i1, 0); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // <= + do { + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantInt::get(i1, 1); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // > + do { + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantInt::get(i1, 0); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // >= + do { + int x = 6; + int y = 6; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantInt::get(i1, 1); + ASSERT_EQ(value->getType(), i1); + ASSERT_EQ(value, expect_value); + } while (false); + + // and, or + do { + } while (false); + + // min + do { + int x = 2; + int y = 3; + auto op = CreateBinaryOp(common::Int(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantInt::get(i32, std::min(x, y)); + ASSERT_EQ(value->getType(), i32); + ASSERT_EQ(value, expect_value); + } while (false); + + // max + do { + float x = 2; + float y = 3; + auto op = CreateBinaryOp(common::Float(32), x, y); + value = emitter->Visit(op.get()); + expect_value = llvm::ConstantFP::get(f32, std::max(x, y)); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + } while (false); + + // minus + // not + + // cast + do { + // i32 -> u32 + // skip + + // i32 -> f32 + LOG(INFO) << "test i32 -> f32"; + int v2 = 2; + auto x2 = std::make_unique(common::Int(32), v2); + auto ex2 = ir::Expr(x2.release()); + auto op2 = ir::Cast::Make(common::Float(32), std::move(ex2)); + value = emitter->Visit(&op2); + expect_value = llvm::ConstantFP::get(f32, v2); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + + // f32 -> i32 + LOG(INFO) << "test f32 -> i32"; + float v3 = 3; + auto x3 = std::make_unique(common::Float(32), v3); + auto ex3 = ir::Expr(x3.release()); + auto 
op3 = ir::Cast::Make(common::Int(32), std::move(ex3)); + value = emitter->Visit(&op3); + expect_value = llvm::ConstantInt::get(i32, v3); + ASSERT_EQ(value->getType(), i32); + ASSERT_EQ(value, expect_value); + + // i32 -> f16 + LOG(INFO) << "test i32 -> f16"; + int v4 = 4; + auto x4 = std::make_unique(common::Int(32), v4); + auto ex4 = ir::Expr(x4.release()); + auto op4 = ir::Cast::Make(common::Float16(), std::move(ex4)); + value = emitter->Visit(&op4); + expect_value = llvm::ConstantFP::get(f16, v4); + ASSERT_EQ(value->getType(), f16); + ASSERT_EQ(value, expect_value); + + // f16 -> f32 + LOG(INFO) << "test f16 -> f32"; + float16 v5{5}; + auto x5 = std::make_unique(common::Float16(), v5); + auto ex5 = ir::Expr(x5.release()); + auto op5 = ir::Cast::Make(common::Float(32), std::move(ex5)); + value = emitter->Visit(&op5); + expect_value = llvm::ConstantFP::get(f32, v5); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + + // i32 -> bf16 + LOG(INFO) << "test i32 -> bf16"; + int v6 = 4; + auto x6 = std::make_unique(common::Int(32), v6); + auto ex6 = ir::Expr(x6.release()); + auto op6 = ir::Cast::Make(common::BFloat16(), std::move(ex6)); + value = emitter->Visit(&op6); + expect_value = llvm::ConstantFP::get(bf16, v6); + ASSERT_EQ(value->getType(), bf16); + ASSERT_EQ(value, expect_value); + + // bf16 -> f32 + LOG(INFO) << "test bf16 -> f32"; + bfloat16 v7{5}; + auto x7 = std::make_unique(common::BFloat16(), v7); + auto ex7 = ir::Expr(x7.release()); + auto op7 = ir::Cast::Make(common::Float(32), std::move(ex7)); + value = emitter->Visit(&op7); + expect_value = llvm::ConstantFP::get(f32, v7); + ASSERT_EQ(value->getType(), f32); + ASSERT_EQ(value, expect_value); + } while (false); +} + +TEST(CodeGenLLVM, Statement) { + return; + std::string outs; + llvm::raw_string_ostream ss(outs); + + do { + auto _m_b_context_emitter_ = CreateCodeGenLLVMTestLLVM(); // NOLINT + auto &m = std::get<0>(_m_b_context_emitter_); + auto &b = std::get<1>(_m_b_context_emitter_); + auto &context = std::get<2>(_m_b_context_emitter_); + auto &emitter = std::get<3>(_m_b_context_emitter_); + auto _i8_i32_i64_u32_f32_f16_ = CreateLLVMType(context.get()); // NOLINT + auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_); + auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_); + auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_); + auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_); + auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_); + auto &f16 = std::get<4>(_i8_i32_i64_u32_f32_f16_); + llvm::FunctionType *function_type = llvm::FunctionType::get(i32, {}, false); + llvm::Function *function = llvm::Function::Create( + function_type, llvm::Function::ExternalLinkage, "codegen_llvm_test.Alloc_Store_Load_Free", m.get()); + + std::string module_str; + module_str += "; ModuleID = 'test_codegen_llvm'"; + module_str += "\nsource_filename = \"test_codegen_llvm\"\n"; + module_str += "\ndefine i32 @codegen_llvm_test.Alloc_Store_Load_Free()"; + + llvm::BasicBlock *entry = llvm::BasicBlock::Create(*context, "entry", function); + b->SetInsertPoint(entry); + + module_str += " {\nentry:"; + + // ir::Tensor + auto tensor_op = CreateIrTensor("x", {2, 3}); + tensor_op->buffer = CreateIrBuffer(common::Int(32), "", {2, 3}); + + // ir::Alloc + auto alloc_op = std::make_unique(); + alloc_op->destination = ir::Expr(tensor_op->buffer); + + // ir::Store + auto store_op = std::make_unique(); + store_op->tensor = ir::Expr(tensor_op); + for (int i : {1, 1}) { + auto pi = std::make_unique(common::Int(32), std::move(i)); + 
store_op->indices.emplace_back(pi.release()); + } + auto store_value = std::make_unique(common::Int(32), 5); + store_op->value = ir::Expr(store_value.release()); + + // ir::Load + auto load_op = std::make_unique(); + load_op->tensor = ir::Expr(tensor_op); + for (int i : {1, 1}) { + auto pi = std::make_unique(common::Int(32), std::move(i)); + load_op->indices.emplace_back(pi.release()); + } + + // ir::Free + auto free_op = std::make_unique(); + free_op->destination = ir::Expr(tensor_op->buffer); + + // ir::Call + auto call_op = std::make_unique(common::Int(32)); + call_op->name = "codegen_llvm_test.Alloc_Store_Load_Free"; + + // Emit llvm ir + auto *alloc_inst = llvm::dyn_cast(emitter->Visit(alloc_op.get())); + module_str += "\n %0 = alloca [6 x i32]"; + auto *store_inst = llvm::dyn_cast(emitter->Visit(store_op.get())); + module_str += "\n %1 = getelementptr [6 x i32], [6 x i32]* %0, i32 1"; + module_str += "\n store i32 5, [6 x i32]* %1"; + auto *load_inst = llvm::dyn_cast(emitter->Visit(load_op.get())); + module_str += "\n %2 = getelementptr [6 x i32], [6 x i32]* %0, i32 1"; + module_str += "\n %3 = load [6 x i32], [6 x i32]* %2"; + + b->CreateRet(llvm::ConstantInt::get(i32, 1)); + + module_str += "\n ret i32 1"; + module_str += "\n}\n"; + + auto log_inst = [&ss, &outs](auto *inst) { + inst->print(ss, false); + LOG(INFO) << inst->getOpcodeName() << " instruction:" << ss.str(); + outs.clear(); + }; + + log_inst(alloc_inst); + log_inst(store_inst); + log_inst(load_inst); + + ASSERT_EQ(module_str, ss.str()); + } while (false); +} + +TEST(CodeGenLLVM, LowerFunc) { + std::string outs; + llvm::raw_string_ostream ss(outs); + + do { + auto context = std::make_unique(); + // auto src_name = m->getSourceFileName(); + llvm::SMDiagnostic error; + std::string runtime_ir(backends::kRuntimeLlvmIr); + // NOTE: read ir string before IRBuilder create + auto m = llvm::parseAssemblyString(runtime_ir, error, *context); + error.print("error:", ss, false); + CHECK(m) << ss.str(); + auto b = std::make_unique>(*context); + + auto emitter = std::make_unique(m.get(), b.get()); + + auto _i8_i32_i64_u32_f32_f16_ = CreateLLVMType(context.get()); // NOLINT + auto &i8 = std::get<0>(_i8_i32_i64_u32_f32_f16_); + auto &i32 = std::get<1>(_i8_i32_i64_u32_f32_f16_); + auto &i64 = std::get<2>(_i8_i32_i64_u32_f32_f16_); + auto &u32 = std::get<3>(_i8_i32_i64_u32_f32_f16_); + auto &f32 = std::get<4>(_i8_i32_i64_u32_f32_f16_); + auto &f16 = std::get<5>(_i8_i32_i64_u32_f32_f16_); + auto _x_y_z_z_buf_ = CreateTensor(); // NOLINT + auto &x = std::get<0>(_x_y_z_z_buf_); + auto &y = std::get<1>(_x_y_z_z_buf_); + auto &z = std::get<2>(_x_y_z_z_buf_); + auto &z_buf = std::get<3>(_x_y_z_z_buf_); + + z->Bind(z_buf); + + auto stages = CreateStages({x, y, z}); + auto function = lang::Lower("add1", stages, {x, y, z}); + ir::Expr func_expr(function); + + auto ir_function = emitter->Visit(&func_expr); + LOG(INFO) << "ir function: " << func_expr; + + auto func = m->getFunction("add1"); + } while (false); +} + +TEST(SymbolTable, test) { + SymbolTable table; + ASSERT_EQ(table.num_scopes(), 0UL); + + table.PushScope(); + + auto *fake_addr = reinterpret_cast(1); + table.Insert("a", fake_addr); + ASSERT_EQ(table.size(), 1UL); + + table.PushScope(); + table.Insert("b", fake_addr); + ASSERT_EQ(table.size(), 1UL); + + auto *a = table.Lookup("a"); + ASSERT_EQ(a, fake_addr); + + auto *b = table.Lookup("b"); + ASSERT_EQ(b, fake_addr); +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/llvm/codegen_x86.cc 
b/paddle/cinn/backends/llvm/codegen_x86.cc new file mode 100644 index 0000000000000..c76b04b16c372 --- /dev/null +++ b/paddle/cinn/backends/llvm/codegen_x86.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/llvm/codegen_x86.h" + +#include +#include + +#include +#include + +#include "cinn/backends/llvm/codegen_llvm.h" +#include "cinn/common/target.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/optim/collect_undefined_vars.h" +#include "cinn/runtime/intrinsic.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/Support/Casting.h" + +namespace cinn::backends { + +CodeGenX86::CodeGenX86(llvm::Module* m, llvm::IRBuilder<>* b, const std::shared_ptr& vars) + : CodeGenLLVM(m, b, vars) {} + +CodeGenX86::~CodeGenX86() {} + +llvm::Value* CodeGenX86::PackVars(const std::vector& vars, uint64_t* num_bytes) { + if (vars.empty()) { + *num_bytes = 0U; + return llvm::Constant::getNullValue(ll_void_p_ty()); + } + std::vector types; + for (auto& v : vars) { + types.push_back(GetVar(v, false)->getType()); + } + llvm::StructType* t_data = llvm::StructType::create(types); + llvm::Value* data = b_->CreateAlloca(t_data, llvm_int32_constant(1)); + for (size_t i = 0; i < vars.size(); ++i) { + b_->CreateStore(GetVar(vars[i]), b_->CreateInBoundsGEP(data, {llvm_int32_constant(0), llvm_int32_constant(i)})); + } + *num_bytes = m_->getDataLayout().getTypeAllocSize(llvm::cast(data->getType())->getElementType()); + return data; +} + +void CodeGenX86::UnpackVars(const std::vector& vars, llvm::Value* data) { + for (size_t i = 0; i < vars.size(); ++i) { + SetVar(vars[i], b_->CreateLoad(b_->CreateInBoundsGEP(data, {llvm_int32_constant(0), llvm_int32_constant(i)}))); + } +} + +llvm::BasicBlock* CodeGenX86::CheckCallSuccess(llvm::Value* retcode) { + llvm::BasicBlock* fail_block = + llvm::BasicBlock::Create(b_->getContext(), "call_fail", b_->GetInsertBlock()->getParent(), nullptr); + llvm::BasicBlock* end_block = + llvm::BasicBlock::Create(b_->getContext(), "call_end", b_->GetInsertBlock()->getParent(), nullptr); + llvm::Value* succ = b_->CreateICmpEQ(retcode, llvm::ConstantInt::get(ll_int32_ty(), 0)); + b_->CreateCondBr(succ, end_block, fail_block); + b_->SetInsertPoint(fail_block); + RetVoid(); + b_->SetInsertPoint(end_block); + return end_block; +} + +void CodeGenX86::CreateParallelLaunch(Expr body, int num_task) { + auto ftype_parallel_lambda = + llvm::FunctionType::get(ll_int32_ty(), {ll_int32_ty(), ll_int32_ty(), ll_type_of(Float(32).PointerOf())}, false); + llvm::Function* f = + llvm::Function::Create(ftype_parallel_lambda, llvm::Function::PrivateLinkage, "__parallel_lambda", m_); + std::vector vars = optim::CollectUndefinedVars(&body); + uint64_t nbytes; + auto* data = PackVars(vars, &nbytes); + + auto ftype_parallel_launch = llvm::FunctionType::get( + ll_int32_ty(), {ftype_parallel_lambda->getPointerTo(), 
ll_type_of(Float(32).PointerOf()), ll_int32_ty()}, false); + auto* launch_callee = llvm::dyn_cast( + m_->getOrInsertFunction(runtime::intrinsic::parallel_launch, ftype_parallel_launch).getCallee()); + launch_callee->setCallingConv(llvm::CallingConv::C); + auto* launch_end = CheckCallSuccess(b_->CreateCall( + launch_callee, + {f, b_->CreatePointerCast(data, ll_type_of(Float(32).PointerOf())), llvm_int32_constant(num_task)})); + + auto* flambda = llvm::BasicBlock::Create(b_->getContext(), "flambda", f); + b_->SetInsertPoint(flambda); + auto it = f->arg_begin(); + auto* task_id = &(*it++); + auto* penv = &(*it++); + data = b_->CreatePointerCast(&(*it++), data->getType()); + symbol_table_->PushScope(); + UnpackVars(vars, data); + ParallelEnv par_env; + auto task_id_name = common::UniqName("task_id"); + auto num_task_name = common::UniqName("num_task"); + par_env.task_id = ir::Var(task_id_name, Int(32)); + par_env.num_task = ir::Var(num_task_name, Int(32)); + SetVar(task_id_name, task_id); + SetVar(num_task_name, penv); + par_env.penv = penv; + std::swap(f_, f); + std::swap(parallel_env_, par_env); + this->Visit(&body); + b_->CreateRet(ll_const_int32(0)); + symbol_table_->Erase(task_id_name); + symbol_table_->Erase(num_task_name); + symbol_table_->PopScope(); + std::swap(parallel_env_, par_env); + std::swap(f_, f); + CHECK_NE(par_env.parallel_loop_count, 0) << "find no parallel loop within parallel launch"; + b_->SetInsertPoint(launch_end); +} + +llvm::Value* CodeGenX86::Visit(const ir::For* op) { + if (op->is_parallel()) { + VLOG(3) << "parallel forloop"; + if (parallel_env_.penv == nullptr) { + CreateParallelLaunch( + ir::For::Make( + op->loop_var, op->min, op->extent, op->for_type(), op->device_api, op->body, op->vectorize_info()), + 0); + } else { + Expr num_task = parallel_env_.num_task; + Expr task_id = parallel_env_.task_id; + CHECK(!parallel_env_.in_parallel_loop) << "Nested parallel loop is not supported, try to fuse them instead"; + parallel_env_.in_parallel_loop = true; + if (parallel_env_.stride_pattern) { + auto new_for = ir::For::Make( + op->loop_var, task_id, op->extent, op->for_type(), op->device_api, op->body, op->vectorize_info()); + auto for_node = new_for.As(); + CHECK(for_node); + CreateSerialFor(for_node, num_task.as_int32()); + } else { + Expr extent = op->extent; + Expr step = (extent + num_task - Expr(1)) / num_task; + Expr begin = min(task_id * step, op->extent); + Expr end = min((task_id + Expr(1)) * step, op->extent); + auto new_for = + ir::For::Make(op->loop_var, begin, end, op->for_type(), op->device_api, op->body, op->vectorize_info()); + auto for_node = new_for.As(); + CHECK(for_node); + CreateSerialFor(for_node); + } + parallel_env_.in_parallel_loop = false; + ++parallel_env_.parallel_loop_count; + } + } else { + return CodeGenLLVM::Visit(op); + } + return nullptr; +} +} // namespace cinn::backends diff --git a/paddle/cinn/backends/llvm/codegen_x86.h b/paddle/cinn/backends/llvm/codegen_x86.h new file mode 100644 index 0000000000000..baf480f51a3d5 --- /dev/null +++ b/paddle/cinn/backends/llvm/codegen_x86.h @@ -0,0 +1,59 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include +#include + +#include "cinn/backends/llvm/codegen_llvm.h" + +namespace cinn::backends { + +class CodeGenX86 : public CodeGenLLVM { + public: + explicit CodeGenX86(llvm::Module* m, llvm::IRBuilder<>* b, const std::shared_ptr& vars = nullptr); + virtual ~CodeGenX86(); + + using LLVMIRVisitor::Visit; + + llvm::Value* Visit(const ir::For* op); + + private: + // parallel information + struct ParallelEnv { + Expr task_id; + Expr num_task; + bool stride_pattern{false}; + bool in_parallel_loop{false}; + int parallel_loop_count{0}; + llvm::Value* penv{nullptr}; + }; + + llvm::Value* ParallelLaunch(); + // Create parallel launch + void CreateParallelLaunch(Expr body, int num_task); + + llvm::Value* PackVars(const std::vector& vars, uint64_t* num_bytes); + void UnpackVars(const std::vector& vars, llvm::Value* data); + llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); + // Current parallel environment scope. + ParallelEnv parallel_env_; +}; + +} // namespace cinn::backends diff --git a/paddle/cinn/backends/llvm/codegen_x86_test.cc b/paddle/cinn/backends/llvm/codegen_x86_test.cc new file mode 100644 index 0000000000000..95ded4776ce56 --- /dev/null +++ b/paddle/cinn/backends/llvm/codegen_x86_test.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
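+
+// A worked example of the static task partition that codegen_x86.cc applies
+// when lowering a parallel `For` (illustrative numbers only, not taken from
+// the tests below): with extent = 10 and num_task = 4,
+//   step = (extent + num_task - 1) / num_task = 3
+// and task i covers [min(i * step, extent), min((i + 1) * step, extent)),
+// i.e. [0, 3), [3, 6), [6, 9), [9, 10) -- the last task absorbs the
+// remainder and no task runs past the extent.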
+ +#include "cinn/backends/llvm/codegen_x86.h" + +#include + +#include "cinn/backends/llvm/simple_jit.h" +#include "cinn/cinn.h" +#include "cinn/common/test_helper.h" +#include "cinn/runtime/cinn_runtime.h" + +namespace cinn { +namespace backends { + +TEST(Vectorize, basic) { + Expr M(1024); + Placeholder A("A", {M}); + Placeholder B("B", {M}); + + auto C = Compute( + {M}, [&](Expr i) { return A(i) + B(i); }, "C"); + auto stages = CreateStages({C}); + + stages[C]->Vectorize(0, 8); + + auto fn = Lower("fn", stages, {A, B, C}); + + LOG(INFO) << "fn: " << fn; + + Module::Builder builder("module", common::DefaultHostTarget()); + builder.AddFunction(fn); + + auto module = builder.Build(); + + LOG(INFO) << "\n" << module->functions[0]; + + auto jit = SimpleJIT::Create(); + jit->Link(builder.Build()); + + auto fn_ = jit->Lookup("fn"); + + auto* fn_ptr = reinterpret_cast(fn_); + + auto* A_buf = common::BufferBuilder(Float(32), {1024}).set_random().set_align(64).Build(); + auto* B_buf = common::BufferBuilder(Float(32), {1024}).set_random().set_align(64).Build(); + auto* C_buf = common::BufferBuilder(Float(32), {1024}).set_zero().set_align(64).Build(); + + auto args = common::ArgsBuilder().Add(A_buf).Add(B_buf).Add(C_buf).Build(); + + fn_ptr(reinterpret_cast(args.data()), args.size()); + + auto* A_data = reinterpret_cast(A_buf->memory); + auto* B_data = reinterpret_cast(B_buf->memory); + auto* C_data = reinterpret_cast(C_buf->memory); + for (int i = 0; i < C_buf->num_elements(); i++) { + ASSERT_NEAR(A_data[i] + B_data[i], C_data[i], 1e-5); + } +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/llvm/execution_engine.cc b/paddle/cinn/backends/llvm/execution_engine.cc new file mode 100644 index 0000000000000..175e58dbdd59b --- /dev/null +++ b/paddle/cinn/backends/llvm/execution_engine.cc @@ -0,0 +1,250 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/backends/llvm/execution_engine.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include // NOLINT +#include +#include + +#include "cinn/backends/codegen_cuda_host.h" +#include "cinn/backends/llvm/cinn_runtime_llvm_ir.h" +#include "cinn/backends/llvm/codegen_llvm.h" +#include "cinn/backends/llvm/codegen_x86.h" +#include "cinn/backends/llvm/llvm_optimizer.h" +#include "cinn/backends/llvm/llvm_util.h" +#include "cinn/backends/llvm/runtime_symbol_registry.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/runtime/intrinsic.h" +#include "cinn/utils/profiler.h" + +namespace cinn::backends { +namespace { +void InitializeLLVMPasses() { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + auto ®istry = *llvm::PassRegistry::getPassRegistry(); + llvm::initializeCore(registry); + llvm::initializeTransformUtils(registry); + llvm::initializeScalarOpts(registry); + llvm::initializeIPO(registry); + llvm::initializeInstCombine(registry); + llvm::initializeAggressiveInstCombine(registry); + llvm::initializeAnalysis(registry); + llvm::initializeVectorization(registry); + llvm::initializeSROALegacyPassPass(registry); + + // llvm::initializeCodeGen(registry); + // llvm::initializeTarget(registry); + // llvm::initializeCodeGenPreparePass(registry); +} +} // namespace +void NaiveObjectCache::notifyObjectCompiled(const llvm::Module *m, llvm::MemoryBufferRef obj_buffer) { + cached_objects_[m->getModuleIdentifier()] = + llvm::MemoryBuffer::getMemBufferCopy(obj_buffer.getBuffer(), obj_buffer.getBufferIdentifier()); +} + +std::unique_ptr NaiveObjectCache::getObject(const llvm::Module *m) { + auto it = cached_objects_.find(m->getModuleIdentifier()); + if (it == cached_objects_.end()) { + VLOG(1) << "No object for " << m->getModuleIdentifier() << " in cache. 
Compiling."; + return nullptr; + } + + VLOG(3) << "Object for " << m->getModuleIdentifier() << " loaded from cache."; + return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef()); +} + +/*static*/ std::unique_ptr ExecutionEngine::Create(const ExecutionOptions &config) { + return Create(config, {}); +} + +/*static*/ std::unique_ptr ExecutionEngine::Create(const ExecutionOptions &config, + RuntimeSymbols &&module_symbols) { + VLOG(1) << "===================== Create CINN ExecutionEngine begin ===================="; + VLOG(1) << "initialize llvm config"; + VLOG(1) << "llvm version: " << LLVM_VERSION_STRING; + VLOG(1) << "llvm default target triple: " << LLVM_DEFAULT_TARGET_TRIPLE; + + static std::once_flag flag; + std::call_once(flag, InitializeLLVMPasses); + + auto engine = std::make_unique(/*enable_object_cache=*/true, std::move(module_symbols)); + + auto compile_layer_creator = [&engine](llvm::orc::JITTargetMachineBuilder jtmb) + -> llvm::Expected> { + auto machine = llvm::cantFail(jtmb.createTargetMachine()); + VLOG(1) << "create llvm compile layer"; + VLOG(1) << "Target Name: " << machine->getTarget().getName(); + VLOG(1) << "Target CPU: " << machine->getTargetCPU().str() << std::endl; + return std::make_unique(std::move(machine), engine->cache_.get()); + }; + + auto object_layer_creator = [&](llvm::orc::ExecutionSession &session, const llvm::Triple &triple) { + auto object_layer = std::make_unique( + session, []() { return std::make_unique(); }); + llvm::orc::JITDylib *main_jd = session.getJITDylibByName("
"); + if (!main_jd) { + main_jd = &llvm::cantFail(session.createJITDylib("
")); + } + return object_layer; + }; + + VLOG(2) << "create jit execution engine"; + engine->jit_ = llvm::cantFail(llvm::orc::LLJITBuilder() + .setCompileFunctionCreator(compile_layer_creator) + .setObjectLinkingLayerCreator(object_layer_creator) + .create()); + engine->jit_->getMainJITDylib().addGenerator(llvm::cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(engine->jit_->getDataLayout().getGlobalPrefix()))); + + VLOG(2) << "register runtime call symbols"; + + engine->RegisterRuntimeSymbols(); + + VLOG(2) << "===================== Create CINN ExecutionEngine end ===================="; + return engine; +} + +template +void ExecutionEngine::Link(const ir::Module &module) { + utils::RecordEvent("ExecutionEngine Link", utils::EventType::kOrdinary); + llvm::SMDiagnostic error; + auto ctx = std::make_unique(); + auto m = llvm::parseAssemblyString(AsStringRef(backends::kRuntimeLlvmIr), error, *ctx); + auto b = std::make_unique>(*ctx); + auto ir_emitter = std::make_unique(m.get(), b.get()); + VLOG(3) << "ir_emitter->Compile(module) Begin"; + ir_emitter->Compile(module); + VLOG(3) << "ir_emitter->Compile(module) Succeed!"; + CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found"; + + auto machine = + std::move(llvm::cantFail(llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()).createTargetMachine())); + LLVMModuleOptimizer optimize(machine.get(), 3, {}, true); + optimize(m.get()); + CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid optimized module detected"; + for (auto &f : *m) { + VLOG(5) << "function: " << DumpToString(f); + } + + llvm::raw_svector_ostream rawstream(buffer_); + llvm::legacy::PassManager pass_manager; + machine->addPassesToEmitFile(pass_manager, rawstream, nullptr, llvm::CGFT_ObjectFile); + pass_manager.run(*m); + + CHECK(AddModule(std::move(m), std::move(ctx))); + + if (VLOG_IS_ON(5)) { + VLOG(5) << "======= dump jit execution session ======"; + std::string buffer; + llvm::raw_string_ostream os(buffer); + decltype(auto) es = jit_->getExecutionSession(); + es.dump(os); + os.flush(); + VLOG(5) << buffer; + } +} + +bool ExecutionEngine::AddModule(std::unique_ptr module, std::unique_ptr context) { + utils::RecordEvent("ExecutionEngine AddModule", utils::EventType::kOrdinary); + module->setDataLayout(jit_->getDataLayout()); + if (VLOG_IS_ON(5)) { + VLOG(5) << "======= dump jit lib =========="; + std::string buffer; + llvm::raw_string_ostream os(buffer); + module->print(os, {}); + // main_jd_->dump(os); + os.flush(); + VLOG(5) << buffer; + } + llvm::orc::ThreadSafeContext tsc(std::move(context)); + llvm::orc::ThreadSafeModule tsm(std::move(module), std::move(tsc)); + llvm::cantFail(jit_->addIRModule(std::move(tsm))); + return true; +} + +void ExecutionEngine::ExportObject(const std::string &path) { + FILE *of = fopen(path.c_str(), "w"); + fwrite(buffer_.data(), 1, buffer_.size(), of); + fclose(of); +} + +void *ExecutionEngine::Lookup(absl::string_view name) { + utils::RecordEvent("ExecutionEngine Lookup", utils::EventType::kOrdinary); + std::lock_guard lock(mu_); + if (auto symbol = jit_->lookup(AsStringRef(name))) { + return reinterpret_cast(symbol->getAddress()); + } + + LOG(ERROR) << "Unknown symbol name[" << name << "]"; + return nullptr; +} + +void ExecutionEngine::RegisterRuntimeSymbols() { + utils::RecordEvent("ExecutionEngine RegisterRuntimeSymbols", utils::EventType::kOrdinary); + const auto ®istry = GlobalSymbolRegistry::Global(); + auto *session = &jit_->getExecutionSession(); + for (const auto &sym : 
registry.All()) { + llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols( + {{session->intern(sym.first), {llvm::pointerToJITTargetAddress(sym.second), llvm::JITSymbolFlags::None}}}))); + } + for (const auto &sym : module_symbols_.All()) { + llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols( + {{session->intern(sym.first), {llvm::pointerToJITTargetAddress(sym.second), llvm::JITSymbolFlags::None}}}))); + } +} + +template void ExecutionEngine::Link(const ir::Module &module); +template void ExecutionEngine::Link(const ir::Module &module); +template void ExecutionEngine::Link(const ir::Module &module); + +} // namespace cinn::backends diff --git a/paddle/cinn/backends/llvm/execution_engine.h b/paddle/cinn/backends/llvm/execution_engine.h new file mode 100644 index 0000000000000..15a7e8793a139 --- /dev/null +++ b/paddle/cinn/backends/llvm/execution_engine.h @@ -0,0 +1,104 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "cinn/backends/llvm/codegen_x86.h" +#include "cinn/backends/llvm/llvm_util.h" +#include "cinn/backends/llvm/runtime_symbol_registry.h" +#include "cinn/ir/module.h" + +namespace cinn::backends { + +class NaiveObjectCache : public llvm::ObjectCache { + public: + void notifyObjectCompiled(const llvm::Module *, llvm::MemoryBufferRef) override; + std::unique_ptr getObject(const llvm::Module *) override; + + private: + llvm::StringMap> cached_objects_; +}; + +struct ExecutionOptions { + int opt_level{3}; + bool enable_debug_info{false}; + // TODO(fc500110) + // int num_compile_threads{1}; + // bool enable_fast_math; +}; + +class ExecutionEngine { + public: + static std::unique_ptr Create(const ExecutionOptions &config); + + static std::unique_ptr Create(const ExecutionOptions &config, RuntimeSymbols &&module_symbols); + + void *Lookup(absl::string_view name); + + template + void Link(const ir::Module &module); + + void ExportObject(const std::string &path); + + bool AddModule(std::unique_ptr module, std::unique_ptr context); + + protected: + explicit ExecutionEngine(bool enable_object_cache, RuntimeSymbols &&module_symbols) + : cache_(std::make_unique()), module_symbols_(std::move(module_symbols)) {} + + void RegisterRuntimeSymbols(); + + bool SetupTargetTriple(llvm::Module *module); + + // This may not be a compatible implementation. 
+ friend std::unique_ptr std::make_unique(bool &&, cinn::backends::RuntimeSymbols &&); + + private: + mutable std::mutex mu_; + llvm::SmallString<0> buffer_; + std::unique_ptr jit_; + std::unique_ptr cache_; + RuntimeSymbols module_symbols_; +}; + +} // namespace cinn::backends diff --git a/paddle/cinn/backends/llvm/execution_engine_test.cc b/paddle/cinn/backends/llvm/execution_engine_test.cc new file mode 100644 index 0000000000000..5818f33a645a8 --- /dev/null +++ b/paddle/cinn/backends/llvm/execution_engine_test.cc @@ -0,0 +1,329 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/llvm/execution_engine.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/backends/llvm/cinn_runtime_llvm_ir.h" +#include "cinn/backends/llvm/codegen_llvm.h" +#include "cinn/backends/llvm/runtime_symbol_registry.h" +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/module.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/optim/optimize.h" +#include "cinn/runtime/cpu/host_intrinsics.h" +#include "cinn/runtime/cpu/use_extern_funcs.h" + +namespace cinn { +namespace backends { + +namespace { +bool RegisterKnownSymbols() { + decltype(auto) registry = GlobalSymbolRegistry::Global(); + + registry.RegisterFn("sinf", reinterpret_cast(&sinf)); + registry.RegisterFn("sin", reinterpret_cast(static_cast(&sin))); + + registry.RegisterFn("cosf", reinterpret_cast(&cosf)); + registry.RegisterFn("cos", reinterpret_cast(static_cast(&cos))); + return true; +} + +[[maybe_unused]] bool unused = RegisterKnownSymbols(); + +constexpr int kM = 100; +constexpr int kN = 32; + +auto CreateTestBuffer() { + auto *A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + auto *B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + auto *C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {kM, kN}, 32); + cinn_buffer_malloc(nullptr, A); + cinn_buffer_malloc(nullptr, B); + cinn_buffer_malloc(nullptr, C); + float *Ad = reinterpret_cast(A->memory); + float *Bd = reinterpret_cast(B->memory); + + for (int i = 0; i < A->num_elements(); i++) { + Ad[i] = static_cast(rand()) / RAND_MAX; // NOLINT + Bd[i] = static_cast(rand()) / RAND_MAX; // NOLINT + } + + float *Cd = reinterpret_cast(C->memory); + CHECK_EQ(C->num_elements(), A->num_elements()); + + return std::make_tuple(A, B, C); +} + +auto CreateTestCinnModule() { + ir::Expr M(kM); + ir::Expr N(kN); + lang::Placeholder A("A", {M, N}); + lang::Placeholder B("B", {M, N}); + + lang::Buffer C_buf(Float(32)); + auto C = lang::Compute( + {M, N}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); + C->Bind(C_buf); + + 
common::Target target; + target.arch = common::Target::Arch::X86; + target.bits = common::Target::Bit::k32; + target.os = common::Target::OS::Linux; + ir::Module::Builder builder("module1", target); + + auto stages = CreateStages({C}); + auto funcs = lang::Lower("elementwise_add", stages, {A, B, C}); + + // auto func = optim::Optimize(funcs); + + builder.AddFunction(ir::LoweredFunc(funcs.As())); + return builder.Build(); +} +} // namespace + +TEST(llvm_test01, elementwise_add) { + return; + auto engine = backends::ExecutionEngine::Create({1}); + + auto _a_b_c_ = CreateTestBuffer(); // NOLINT + auto &a = std::get<0>(_a_b_c_); + auto &b = std::get<1>(_a_b_c_); + auto &c = std::get<2>(_a_b_c_); + + auto module = CreateTestCinnModule(); + + engine->Link(module); + + auto elementwise_add_addr = engine->Lookup("elementwise_add"); + return; + auto elementwise_add = reinterpret_cast(elementwise_add_addr); + cinn_pod_value_t a_arg(a), b_arg(b), c_arg(c); + cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; + elementwise_add(args, 3); + + float *ad = reinterpret_cast(a->memory); + float *bd = reinterpret_cast(b->memory); + float *cd = reinterpret_cast(c->memory); + + for (int i = 0; i < c->num_elements(); i++) { + EXPECT_EQ(ad[i] + bd[i], cd[i]); + } +} + +TEST(llvm, module_call_lowered_func) { + ir::Module::Builder builder("some_module", common::DefaultHostTarget()); + ir::Expr M(kM); + ir::Expr N(kN); + { // define fn + lang::Placeholder a("A", {M, N}); + lang::Placeholder b("B", {M, N}); + auto c = lang::Compute( + {M, N}, [&](auto i, auto j) { return a(i, j) + b(i, j); }, "C"); + + auto stages = CreateStages({c}); + auto fn = lang::Lower("elementwise_add", stages, {a, b, c}, {}); + builder.AddFunction(fn); + } + + { // call fn + lang::Placeholder a("A", {M, N}); + lang::Placeholder b("B", {M, N}); + + std::vector ret_types({lang::ReturnType{Float(32), {M, N}, "c_out"}}); + + auto call_outs = lang::CallLowered("elementwise_add", {a, b}, ret_types); + auto c = call_outs[0]; + + // here we must call the output, so that it cal output something. 
+ + auto stages = CreateStages({c}); + auto main_fn = lang::Lower("main", stages, {a, b, c}, {}); + builder.AddFunction(main_fn); + + CodeGenC codegen(common::DefaultHostTarget()); + codegen.SetInlineBuiltinCodes(false); + LOG(INFO) << "module:\n" << codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + } + + auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT + auto &ab = std::get<0>(_ab_bb_cb_); + auto &bb = std::get<1>(_ab_bb_cb_); + auto &cb = std::get<2>(_ab_bb_cb_); + do { // call the function + auto engine = backends::ExecutionEngine::Create({1}); + + LOG(INFO) << "JIT Link the module"; + engine->Link(builder.Build()); + auto cos_fn = (double (*)(double))engine->Lookup("cos"); + LOG(INFO) << "=> LLVM JIT cos(0) = " << cos_fn(0); + auto elementwise_add_addr = engine->Lookup("elementwise_add"); + auto elementwise_add = reinterpret_cast(elementwise_add_addr); + LOG(INFO) << "JIT get elementwise_add_addr"; + break; + + cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb); + cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; + + elementwise_add(args, 3); + + auto *ad = reinterpret_cast(ab->memory); + auto *bd = reinterpret_cast(bb->memory); + for (int i = 0; i < kM; i++) { + for (int j = 0; j < kN; j++) { + auto *data = reinterpret_cast(cb->memory); + ASSERT_NEAR(data[i * kN + j], ad[i * kN + j] + bd[i * kN + j], 1e-5); + } + } + } while (false); +} + +TEST(ExecutionEngine, custom_runtime_symbols) { + auto context = std::make_unique(); + auto module = std::make_unique("test_llvm_cpu_runtime", *context); + auto builder = std::make_unique>(*context); + + auto call_custom_target = [&](std::string name, llvm::Type *ty) { + llvm::FunctionType *fn_type = llvm::FunctionType::get(ty, {ty}, false); + llvm::Function *function = + llvm::Function::Create(fn_type, llvm::Function::ExternalLinkage, "_call_custom_" + name, module.get()); + function->setCallingConv(llvm::CallingConv::C); + llvm::BasicBlock *entry = llvm::BasicBlock::Create(module->getContext(), "entry", function); + builder->SetInsertPoint(entry); + llvm::Argument *arg = &*function->args().begin(); + llvm::Function *custom_function = + llvm::dyn_cast(module->getOrInsertFunction(name, fn_type).getCallee()); + custom_function->setCallingConv(llvm::CallingConv::C); + llvm::Value *ret = builder->CreateCall(custom_function, {arg}); + builder->CreateRet(ret); + }; + + llvm::Type *f32 = builder->getFloatTy(); + llvm::Type *f64 = builder->getDoubleTy(); + call_custom_target("cosf", f32); + call_custom_target("cos", f64); + call_custom_target("sinf", f32); + call_custom_target("sin", f64); + + double pi = std::acos(-1); + + std::vector angle = {0., pi / 6., pi / 4., pi / 3., pi / 2., pi}; + + std::random_device rd; + std::mt19937 mt(rd()); + std::uniform_int_distribution dis(-100, 100); + int random_x = dis(mt); + int random_y = dis(mt); + + decltype(auto) registry = GlobalSymbolRegistry::Global(); + // registry.Register("dereference_f64_ptr", (void *)+[](double *x) { return *x; }); + + for (size_t i = 0; i < angle.size(); i++) { + registry.RegisterVar("theta_" + std::to_string(i), angle[i]); + } + + auto engine = cinn::backends::ExecutionEngine::Create({1}); + engine->AddModule(std::move(module), std::move(context)); + + auto *call_cosf = reinterpret_cast(engine->Lookup("_call_custom_cosf")); + auto *call_cos = reinterpret_cast(engine->Lookup("_call_custom_cos")); + auto *call_sinf = reinterpret_cast(engine->Lookup("_call_custom_sinf")); + auto *call_sin = reinterpret_cast(engine->Lookup("_call_custom_sin")); + + ASSERT_TRUE(call_cosf && 
call_cos && call_sinf && call_sin); + + for (auto theta : angle) { + float theta_f = static_cast(theta); + ASSERT_NEAR(call_cosf(theta_f), cosf(theta_f), 1e-6); + ASSERT_NEAR(call_cos(theta), cos(theta), 1e-6); + ASSERT_NEAR(call_sinf(theta_f), sinf(theta_f), 1e-6); + ASSERT_NEAR(call_sin(theta), sin(theta), 1e-6); + } +} + +TEST(ExecutionEngine, call_extern) { + ir::Expr M(kM); + ir::Expr N(kN); + + Placeholder x("x", {M, N}); + Placeholder y("y", {M, N}); + + auto add_out = Compute( + {M, N}, [=](Var i, Var j) { return x(i, j) + y(i, j); }, "add_out"); + + ir::Tensor res = Compute( + {M, N}, [&](Var i, Var j) -> Expr { return lang::CallExtern("tanh", {add_out(i, j)}); }, "res"); + + auto stages = CreateStages({add_out, res}); + + stages[add_out]->ComputeInline(); + auto func = Lower("comp", stages, {x, y, res}); + + Module::Builder builder("module0", common::DefaultHostTarget()); + builder.AddFunction(func); + + auto engine = backends::ExecutionEngine::Create({1}); + + engine->Link(builder.Build()); + + auto _ab_bb_cb_ = CreateTestBuffer(); // NOLINT + auto &ab = std::get<0>(_ab_bb_cb_); + auto &bb = std::get<1>(_ab_bb_cb_); + auto &cb = std::get<2>(_ab_bb_cb_); + + auto comp_addr = engine->Lookup("comp"); + auto comp = reinterpret_cast(comp_addr); + + cinn_pod_value_t a_arg(ab), b_arg(bb), c_arg(cb); + cinn_pod_value_t args[3] = {a_arg, b_arg, c_arg}; + + comp(args, 3); + + auto *ad = reinterpret_cast(ab->memory); + auto *bd = reinterpret_cast(bb->memory); + auto *cd = reinterpret_cast(cb->memory); + for (int m = 0; m < kM; m++) { + for (int n = 0; n < kN; n++) { + ASSERT_NEAR(cd[m * kN + n], tanh(ad[m * kN + n] + bd[m * kN + n]), 1e-5); + } + } +} + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py b/paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py new file mode 100644 index 0000000000000..2d8d93aa5d334 --- /dev/null +++ b/paddle/cinn/backends/llvm/generate_runtime_llvm_ir.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
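+
+# A sketch of the header this script generates (illustrative; the string
+# literal holds the verbatim contents of the input .ll file, and llvm_version
+# holds the digits reported by `llvm-config --version`, so the values below
+# are examples only):
+#
+#   namespace cinn::backends {
+#   static const absl::string_view kRuntimeLlvmIr(
+#   R"ROC(
+#   ; ...contents of the input .ll file...
+#   )ROC"
+#   );
+#   struct llvm_version {
+#     static constexpr int kMajor = 11;
+#     static constexpr int kMinor = 0;
+#     static constexpr int kMicro = 0;
+#   };
+#   }  // namespace cinn::backends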
+ +import sys +import subprocess + + +def main(): + path = sys.argv[1] + out_path = sys.argv[2] + llvm_config = sys.argv[3] + + srcs = [] + srcs.append('#include ') + #srcs.append('#include "cinn/backends/llvm/cinn_runtime_llvm_ir.h"\n') + srcs.append('namespace cinn::backends {') + srcs.append("static const absl::string_view kRuntimeLlvmIr(") + srcs.append('R"ROC(') + with open(path, 'r') as fr: + srcs.append(fr.read()) + + srcs.append(')ROC"') + srcs.append(');\n') + + cmd = "{} --version".format(llvm_config) + version = subprocess.check_output( + cmd, shell=True).decode('utf-8').strip().split('.') + srcs.append("struct llvm_version {") + for v, n in zip(["major", "minor", "micro"], version): + srcs.append(" static constexpr int k{} = {};".format( + v.title(), ''.join(filter(str.isdigit, n)))) + srcs.append("};") + + srcs.append('} // namespace cinn::backends') + with open(out_path, 'w') as fw: + fw.write("\n".join(srcs)) + + +def get_clang_version(): + pass + + +if __name__ == "__main__": + main() diff --git a/paddle/cinn/backends/llvm/ir_builder_mixin.h b/paddle/cinn/backends/llvm/ir_builder_mixin.h new file mode 100644 index 0000000000000..42b1e9663afbb --- /dev/null +++ b/paddle/cinn/backends/llvm/ir_builder_mixin.h @@ -0,0 +1,306 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
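+
+// IrBuilderMixin is a CRTP mixin: the deriving class passes itself as the
+// template argument and must expose an `llvm::IRBuilder<> *b()` accessor,
+// which every forwarding helper here reaches via mixin_builder(). A minimal
+// sketch of a client (hypothetical class, for illustration only):
+//
+//   class MyCodeGen : public IrBuilderMixin<MyCodeGen> {
+//    public:
+//     llvm::IRBuilder<> *b() { return builder_; }  // hook the mixin requires
+//     llvm::Value *AddInts(llvm::Value *x, llvm::Value *y) {
+//       return Add(x, y);  // forwards to builder_->CreateAdd(x, y)
+//     }
+//
+//    private:
+//     llvm::IRBuilder<> *builder_{nullptr};
+//   };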

#pragma once

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>

#include <utility>

namespace cinn {
namespace backends {
template <typename T>
class IrBuilderMixin {
 protected:
  template <typename... Args>
  decltype(auto) BinOp(Args &&...args) {
    return mixin_builder()->CreateBinOp(std::forward<Args>(args)...);
  }

  /// \brief +
  template <typename... Args>
  decltype(auto) Add(Args &&...args) {
    return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FAdd(Args &&...args) {
    return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) NSWAdd(Args &&...args) {
    return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
  }

  /// \brief -
  template <typename... Args>
  decltype(auto) Sub(Args &&...args) {
    return mixin_builder()->CreateSub(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FSub(Args &&...args) {
    return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) NSWSub(Args &&...args) {
    return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
  }

  /// \brief *
  template <typename... Args>
  decltype(auto) Mul(Args &&...args) {
    return mixin_builder()->CreateMul(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FMul(Args &&...args) {
    return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) NSWMul(Args &&...args) {
    return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
  }

  /// \brief /
  template <typename... Args>
  decltype(auto) SDiv(Args &&...args) {
    return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) UDiv(Args &&...args) {
    return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FDiv(Args &&...args) {
    return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
  }

  /// \brief %
  template <typename... Args>
  decltype(auto) SRem(Args &&...args) {
    return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) URem(Args &&...args) {
    return mixin_builder()->CreateURem(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FRem(Args &&...args) {
    return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) And(Args &&...args) {
    return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Or(Args &&...args) {
    return mixin_builder()->CreateOr(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Not(Args &&...args) {
    return mixin_builder()->CreateNot(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) Neg(Args &&...args) {
    return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FNeg(Args &&...args) {
    return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) ICmpEQ(Args &&...args) {
    return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpOEQ(Args &&...args) {
    return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpUEQ(Args &&...args) {
    return mixin_builder()->CreateFCmpUEQ(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpNE(Args &&...args) {
    return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpONE(Args &&...args) {
    return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpUNE(Args &&...args) {
    return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpULE(Args &&...args) {
    return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpOLE(Args &&...args) {
    return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpULT(Args &&...args) {
    return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpSLT(Args &&...args) {
    return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpOLT(Args &&...args) {
    return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpUGE(Args &&...args) {
    return mixin_builder()->CreateICmpUGE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpSGE(Args &&...args) {
    return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpOGE(Args &&...args) {
    return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpUGT(Args &&...args) {
    return mixin_builder()->CreateICmpUGT(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) ICmpSGT(Args &&...args) {
    return mixin_builder()->CreateICmpSGT(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FCmpOGT(Args &&...args) {
    return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) BitCast(Args &&...args) {
    return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) IntCast(Args &&...args) {
    return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FPCast(Args &&...args) {
    return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) PointerCast(Args &&...args) {
    return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) FPToSI(Args &&...args) {
    return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) FPToUI(Args &&...args) {
    return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) SIToFP(Args &&...args) {
    return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) UIToFP(Args &&...args) {
    return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) Select(Args &&...args) {
    return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Br(Args &&...args) {
    return mixin_builder()->CreateBr(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) CondBr(Args &&...args) {
    return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) Alloca(Args &&...args) {
    return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Load(Args &&...args) {
    return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) AlignedLoad(Args &&...args) {
    return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Store(Args &&...args) {
    return mixin_builder()->CreateStore(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) AlignedStore(Args &&...args) {
    return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) Call(Args &&...args) {
    return mixin_builder()->CreateCall(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) RetVoid(Args &&...args) {
    return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) GEP(Args &&...args) {
    return mixin_builder()->CreateGEP(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) InBoundsGEP(Args &&...args) {
    return mixin_builder()->CreateInBoundsGEP(std::forward<Args>(args)...);
  }
  template <typename... Args>
  decltype(auto) PHI(Args &&...args) {
    return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) InsertValue(Args &&...args) {
    return mixin_builder()->CreateInsertValue(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) ExtractValue(Args &&...args) {
    return mixin_builder()->CreateExtractValue(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) InsertElement(Args &&...args) {
    return mixin_builder()->CreateInsertElement(std::forward<Args>(args)...);
  }

  template <typename... Args>
  decltype(auto) ShuffleVector(Args &&...args) {
    return mixin_builder()->CreateShuffleVector(std::forward<Args>(args)...);
  }

 private:
  llvm::IRBuilder<> *mixin_builder() { return static_cast<T *>(this)->b(); }
};
}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/llvm_intrin_rule.h b/paddle/cinn/backends/llvm/llvm_intrin_rule.h
new file mode 100644
index 0000000000000..822349f8a8ae9
--- /dev/null
+++ b/paddle/cinn/backends/llvm/llvm_intrin_rule.h
@@ -0,0 +1,177 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <glog/logging.h>
#include <llvm/IR/Intrinsics.h>

#include <string>

#include "cinn/cinn.h"
#include "cinn/ir/intrinsic_ops.h"
#include "cinn/ir/registry.h"
#include "cinn/lang/packed_func.h"

namespace cinn {
namespace codegen {

template <int id, int arg_nums, bool add_float_suffix = true>
inline void MakeFloatIntrinOp(lang::Args args, lang::RetValue *rv) {
  CHECK_GE(args.size(), 1U);
  Expr arg = args[0];
  ir::Call *node = arg->as<ir::Call>();
  CHECK(node);
  CHECK_GE(node->read_args.size(), arg_nums);
  if (add_float_suffix) {
    CHECK(node->type().is_float());
    *rv = ir::intrinsics::BuiltinIntrin::Make(node->name + "f", node->read_args, id, arg_nums, node->type());
  } else {
    *rv = ir::intrinsics::BuiltinIntrin::Make(node->name, node->read_args, id, arg_nums, node->type());
  }
}

void RegisterCpuIntrinRule() {
#define __(intrin_name__, id)                                           \
  ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true)   \
      .SetBody(MakeFloatIntrinOp<id, 1>);
  __(exp, ::llvm::Intrinsic::exp)
  __(exp2, ::llvm::Intrinsic::exp2)
  __(sqrt, ::llvm::Intrinsic::sqrt)
  __(log, ::llvm::Intrinsic::log)
  __(log2, ::llvm::Intrinsic::log2)
  __(log10, ::llvm::Intrinsic::log10)
  __(floor, ::llvm::Intrinsic::floor)
  __(ceil, ::llvm::Intrinsic::ceil)
  __(round, ::llvm::Intrinsic::round)
  __(trunc, ::llvm::Intrinsic::trunc)
  __(cos, ::llvm::Intrinsic::cos)
  __(sin, ::llvm::Intrinsic::sin)
  __(fabs, ::llvm::Intrinsic::fabs)
#undef __

// set id to -1 if the op is not an llvm intrinsic
#define RegisterBitwise(intrin_name__)                                  \
  ir::Registry::Register("lower_cpu_intrinsic_" #intrin_name__, true)   \
      .SetBody(MakeFloatIntrinOp<-1, 2, false>);
  RegisterBitwise(bitwise_or) RegisterBitwise(bitwise_xor) RegisterBitwise(bitwise_and)
      RegisterBitwise(left_shift) RegisterBitwise(right_shift)
#undef RegisterBitwise

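  // For orientation: each __(name, id) entry above expands roughly to
  //   ir::Registry::Register("lower_cpu_intrinsic_exp", true)
  //       .SetBody(MakeFloatIntrinOp<::llvm::Intrinsic::exp, 1>);
  // so a float call to the CINN intrinsic `exp` is lowered to llvm.exp, with
  // MakeFloatIntrinOp appending the "f" suffix for the float variant.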
  ir::Registry::Register("lower_cpu_intrinsic_fma", true)
      .SetBody(MakeFloatIntrinOp<::llvm::Intrinsic::fmuladd, 3, false>);

  ir::Registry::Register("lower_cpu_intrinsic_bitwise_not", true).SetBody(MakeFloatIntrinOp<-1, 1, false>);

  ir::Registry::Register("lower_cpu_intrinsic_isnan", true).SetBody(MakeFloatIntrinOp<-1, 1, false>);

  ir::Registry::Register("lower_cpu_intrinsic_isfinite", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    *rv = !(lang::IsInf(arg)) && !(lang::IsNan(arg));
  });

  ir::Registry::Register("lower_cpu_intrinsic_isinf", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    Type type = arg->type();
    if (type.is_int() || type.is_uint()) {
      *rv = common::make_bool(false, type.lanes());
    } else if (type.is_float()) {
      *rv = ir::EQ::Make(lang::Abs(arg), lang::Infinity(type)) && !(lang::IsNan(arg));
    }
  });

  ir::Registry::Register("lower_cpu_intrinsic_rsqrt", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    *rv = make_const(arg->type(), 1) / lang::Sqrt(arg);
  });

  ir::Registry::Register("lower_cpu_intrinsic_exp10", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    Expr ln10 = make_const(arg->type(), 2.302585093);
    *rv = lang::Exp(arg * ln10);
  });

  ir::Registry::Register("lower_cpu_intrinsic_tan", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    *rv = lang::Sin(arg) / lang::Cos(arg);
  });

  ir::Registry::Register("lower_cpu_intrinsic_tanh", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    Expr zero = make_const(arg->type(), 0);
    Expr one = make_const(arg->type(), 1);
    Expr two = make_const(arg->type(), 2);
    Expr neg_two = make_const(arg->type(), -2);

    Expr exp_neg2x = lang::Exp(neg_two * arg);
    Expr exp_pos2x = lang::Exp(two * arg);

    Expr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
    Expr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
    *rv = ir::Select::Make(arg >= zero, tanh_pos, tanh_neg);
  });

  ir::Registry::Register("lower_cpu_intrinsic_cosh", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    *rv = (lang::Exp(arg) + lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2);
  });

  ir::Registry::Register("lower_cpu_intrinsic_sinh", true).SetBody([](lang::Args args, lang::RetValue *rv) {
    CHECK_GE(args.size(), 1U);
    Expr arg0 = args[0];
    ir::Call *node = arg0->as<ir::Call>();
    CHECK(node);
    CHECK(!node->read_args.empty());
    Expr arg = node->read_args[0];
    *rv = (lang::Exp(arg) - lang::Exp(arg * make_const(arg->type(), -1))) / make_const(arg->type(), 2);
  });
}
}  // namespace codegen
}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/llvm_optimizer.cc b/paddle/cinn/backends/llvm/llvm_optimizer.cc
new file mode 100644
index 0000000000000..ff5c60d74fd7a
--- /dev/null
+++ b/paddle/cinn/backends/llvm/llvm_optimizer.cc
@@ -0,0 +1,166 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/llvm/llvm_optimizer.h"

#include <glog/logging.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Pass.h>
#include <llvm/Support/Error.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>

#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "llvm/Support/CodeGen.h"

namespace cinn::backends {

namespace {
template <typename PassManagerT>
class CustomPassManager : public PassManagerT {
 public:
  template <typename... Ts>
  explicit CustomPassManager(bool print_passes, Ts &&...ts)
      : PassManagerT(std::forward<Ts>(ts)...), print_passes_(print_passes) {}

  void add(llvm::Pass *pass) override {
    if (print_passes_) {
      if (is_function_pass_manager_) {
        VLOG(1) << "llvm run function pass[" << std::string(pass->getPassName()) << "]";
      }

      if (is_module_pass_manager_) {
        VLOG(1) << "llvm run module pass[" << std::string(pass->getPassName()) << "]";
      }
    }
    // static bool add_pass = true;
    // if (add_pass) {
    //   PassManagerT::add(pass);
    // }

    // if (std::string(pass->getPassName()) == "Loop Vectorization") {
    //   return;
    // }
    PassManagerT::add(pass);
  }

  void run(llvm::Function &f) {  // NOLINT
    if (is_function_pass_manager_) {
      PassManagerT::run(f);
    }
  }

  void run(llvm::Module &m) {  // NOLINT
    if (is_module_pass_manager_) {
      PassManagerT::run(m);
    }
  }

 private:
  static constexpr bool is_function_pass_manager_ =
      std::is_same<PassManagerT, llvm::legacy::FunctionPassManager>::value;
  static constexpr bool is_module_pass_manager_ =
      std::is_same<PassManagerT, llvm::legacy::PassManager>::value;
  bool print_passes_;
};

using CustomFunctionPassManager = CustomPassManager<llvm::legacy::FunctionPassManager>;
using CustomModulePassManager = CustomPassManager<llvm::legacy::PassManager>;
}  // namespace

LLVMModuleOptimizer::LLVMModuleOptimizer(llvm::TargetMachine *machine,
                                         int opt_level,
                                         llvm::FastMathFlags fast_math_flags,
                                         bool print_passes)
    : opt_level_(opt_level), print_passes_(print_passes), machine_(machine) {}

void LLVMModuleOptimizer::operator()(llvm::Module *m) {
  auto machine = std::move(llvm::cantFail(
      llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()).createTargetMachine()));
  auto fpm = std::make_unique<CustomFunctionPassManager>(print_passes_, m);
  // fpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis()));
  // fpm->add(llvm::createInstructionCombiningPass());
  // fpm->add(llvm::createReassociatePass());
  // fpm->add(llvm::createGVNPass());
  // fpm->add(llvm::createCFGSimplificationPass());
  // fpm->add(llvm::createSROAPass());
  // fpm->add(llvm::createEarlyCSEPass());
  // fpm->add(llvm::createLowerExpectIntrinsicPass());
  // fpm->add(llvm::createCallSiteSplittingPass());
  // fpm->add(llvm::createLoopVectorizePass());
  // fpm->add(llvm::createSLPVectorizerPass());
  // fpm->add(llvm::createLoadStoreVectorizerPass());
  // fpm->add(llvm::createLoopUnrollPass());

  auto mpm = std::make_unique<CustomModulePassManager>(print_passes_);
  // mpm->add(llvm::createTargetTransformInfoWrapperPass(llvm::TargetIRAnalysis()));
  // LOG(INFO) << "llvm run pass: target machine: name[" << machine_->getTarget().getName() << "]";
  // LOG(INFO) << "llvm run pass: target machine: cpu[" << machine_->getTargetCPU().str() << "]";
  fpm->add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
  mpm->add(llvm::createTargetTransformInfoWrapperPass(machine->getTargetIRAnalysis()));
  auto builder = std::make_unique<llvm::PassManagerBuilder>();
  builder->OptLevel = opt_level_;
  builder->Inliner = llvm::createFunctionInliningPass();
  builder->LoopVectorize = true;
  builder->SLPVectorize = true;
#if LLVM_VERSION_MAJOR >= 11
  machine->adjustPassManager(*builder);
#endif
  builder->populateFunctionPassManager(*fpm);
  builder->populateModulePassManager(*mpm);

  fpm->doInitialization();
  std::for_each(m->begin(), m->end(), [&fpm](auto &fn) { fpm->run(fn); });
  fpm->doFinalization();

  mpm->run(*m);
}

}  // namespace cinn::backends
diff --git a/paddle/cinn/backends/llvm/llvm_optimizer.h b/paddle/cinn/backends/llvm/llvm_optimizer.h
new file mode 100644
index 0000000000000..ea613c1da0b2b
--- /dev/null
+++ b/paddle/cinn/backends/llvm/llvm_optimizer.h
@@ -0,0 +1,43 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <llvm/IR/Module.h>
#include <llvm/IR/Operator.h>
#include <llvm/Target/TargetMachine.h>

#include <memory>

namespace cinn::backends {

// TODO(fc500110): define class OptimizeOptions

// llvm module optimizer
class LLVMModuleOptimizer final {
 public:
  explicit LLVMModuleOptimizer(llvm::TargetMachine *machine,
                               int opt_level,
                               llvm::FastMathFlags fast_math_flags,
                               bool print_passes = false);
  void operator()(llvm::Module *m);

 private:
  llvm::TargetMachine *machine_;
  int opt_level_{};
  bool print_passes_{};
};
}  // namespace cinn::backends
diff --git a/paddle/cinn/backends/llvm/llvm_util.cc b/paddle/cinn/backends/llvm/llvm_util.cc
new file mode 100644
index 0000000000000..e03325faf4d21
--- /dev/null
+++ b/paddle/cinn/backends/llvm/llvm_util.cc
@@ -0,0 +1,146 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/llvm/llvm_util.h" + +#include +#include + +#include +#include //NOLINT + +namespace cinn { +namespace backends { + +using cinn::common::bfloat16; +using cinn::common::float16; + +llvm::Type *CinnTypeToLLVMType(common::Type type, llvm::Module *m, bool is_vec) { + llvm::Type *ir_type = nullptr; + if (type.is_cpp_const()) { + // TODO(fc500110) support it latter. + } + + llvm::Type *v = llvm::Type::getVoidTy(m->getContext()); + + llvm::Type *i1 = llvm::Type::getInt1Ty(m->getContext()); + + llvm::Type *i8 = llvm::Type::getInt8Ty(m->getContext()); + llvm::Type *i16 = llvm::Type::getInt16Ty(m->getContext()); + llvm::Type *i32 = llvm::Type::getInt32Ty(m->getContext()); + llvm::Type *i64 = llvm::Type::getInt64Ty(m->getContext()); + + llvm::Type *u8 = llvm::Type::getInt8Ty(m->getContext()); + llvm::Type *u16 = llvm::Type::getInt16Ty(m->getContext()); + llvm::Type *u32 = llvm::Type::getInt32Ty(m->getContext()); + llvm::Type *u64 = llvm::Type::getInt64Ty(m->getContext()); + + llvm::Type *bf16 = llvm::Type::getBFloatTy(m->getContext()); + llvm::Type *f16 = llvm::Type::getHalfTy(m->getContext()); + llvm::Type *f32 = llvm::Type::getFloatTy(m->getContext()); + llvm::Type *f64 = llvm::Type::getDoubleTy(m->getContext()); + llvm::Type *arr = llvm::Type::getPrimitiveType(m->getContext(), llvm::Type::ArrayTyID); + if (type.is_void() && type.is_cpp_handle()) { + return llvm::PointerType::getUnqual(i8); + } + if (type.is_void() && type.is_cpp_handle2()) { + return llvm::PointerType::getUnqual(llvm::PointerType::getUnqual(i8)); + } + + if (type.is_bool()) { + ir_type = i1; + } else if (type.is_int(8)) { + ir_type = i8; + } else if (type.is_int(16)) { + ir_type = i16; + } else if (type.is_int(32)) { + ir_type = i32; + } else if (type.is_int(64)) { + ir_type = i64; + } else if (type.is_uint(8)) { + ir_type = u8; + } else if (type.is_uint(16)) { + ir_type = u16; + } else if (type.is_uint(32)) { + ir_type = u32; + } else if (type.is_uint(64)) { + ir_type = u64; + } else if (type.is_float(32)) { + ir_type = f32; + } else if (type.is_float(64)) { + ir_type = f64; + } else if (type.is_bfloat16()) { + ir_type = bf16; + } else if (type.is_float16()) { + ir_type = f16; + } else if (type.is_void()) { + ir_type = v; + } else if (type.is_string()) { + ir_type = arr; + } else if (type.is_customized_type()) { + CHECK(!type.customized_type().empty()); + ir_type = m->getTypeByName("struct." + type.customized_type()); + } + CHECK(ir_type) << "LLVM can't convert type: " << type; + + // C array / vector. 
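  // A type with lanes > 1 is widened below: to an LLVM SIMD vector type
  // ("<lanes x T>") when is_vec is set, otherwise to a plain array type
  // ("[lanes x T]").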
  if (type.lanes() > 1) {
    if (is_vec) {
      ir_type = llvm::FixedVectorType::get(ir_type, type.lanes());
    } else {
      ir_type = llvm::ArrayType::get(ir_type, type.lanes());
    }
  }

  if (type.is_cpp_handle()) {
    ir_type = llvm::PointerType::getUnqual(ir_type);
  }

  if (type.is_cpp_handle2()) {
    ir_type = llvm::PointerType::getUnqual(ir_type);
    ir_type = llvm::PointerType::getUnqual(ir_type);
  }

  return ir_type;
}

#define __(ty__)                                           \
  template <>                                              \
  llvm::Type *llvm_type_of<ty__>(llvm::Module * m) {       \
    return CinnTypeToLLVMType(common::type_of<ty__>(), m); \
  }

__(int8_t)
__(int16_t)
__(int32_t)
__(int64_t)
__(uint8_t)
__(uint16_t)
__(uint32_t)
__(uint64_t)
__(bfloat16)
__(float16)
__(float)
__(double)
__(cinn_buffer_t)
__(cinn_buffer_t *)
__(cinn_pod_value_t *)
__(cinn_pod_value_t)
__(void *)
__(void **)

#undef __

}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/llvm_util.h b/paddle/cinn/backends/llvm/llvm_util.h
new file mode 100644
index 0000000000000..b53b46af245d8
--- /dev/null
+++ b/paddle/cinn/backends/llvm/llvm_util.h
@@ -0,0 +1,55 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/strings/string_view.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/raw_ostream.h>

#include <string>

#include "cinn/common/type.h"

namespace cinn {
namespace backends {

template <typename T>
std::string DumpToString(const T &entity) {
  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  entity.print(os);
  os.flush();
  return buffer;
  // return "\033[33m" + buffer + "\033[0m";  // Green
}

inline llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); }

llvm::Type *CinnTypeToLLVMType(common::Type t, llvm::Module *m, bool is_vec = false);

template <typename T>
llvm::Type *llvm_type_of(llvm::Module *m);

}  // namespace backends
}  // namespace cinn
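A quick usage sketch of the helpers above (illustrative only, not part of the patch; `common::Float(32, 4)` is assumed to be the CINN spelling of a 4-lane f32 type):

  llvm::LLVMContext ctx;
  llvm::Module m("demo", ctx);
  auto *f32 = cinn::backends::llvm_type_of<float>(&m);             // -> "float"
  auto *vec = cinn::backends::CinnTypeToLLVMType(
      cinn::common::Float(32, /*lanes=*/4), &m, /*is_vec=*/true);  // -> "<4 x float>"
  LOG(INFO) << cinn::backends::DumpToString(*vec);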
diff --git a/paddle/cinn/backends/llvm/runtime_symbol_registry.cc b/paddle/cinn/backends/llvm/runtime_symbol_registry.cc
new file mode 100644
index 0000000000000..796a7f9b69216
--- /dev/null
+++ b/paddle/cinn/backends/llvm/runtime_symbol_registry.cc
@@ -0,0 +1,68 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/llvm/runtime_symbol_registry.h"

#include <absl/strings/string_view.h>
#include <glog/raw_logging.h>

#include <string>

#include "cinn/runtime/flags.h"
#include "gflags/gflags_declare.h"

DECLARE_bool(verbose_function_register);

namespace cinn {
namespace backends {

RuntimeSymbols &GlobalSymbolRegistry::Global() {
  static RuntimeSymbols symbols;
  return symbols;
}

void *RuntimeSymbols::Lookup(absl::string_view name) const {
  std::lock_guard<std::mutex> lock(mu_);
  auto it = symbols_.find(std::string(name));
  if (it != symbols_.end()) {
    return it->second;
  }

  return nullptr;
}

void RuntimeSymbols::Register(const std::string &name, void *address) {
#ifdef CINN_WITH_DEBUG
  if (FLAGS_verbose_function_register) {
    RAW_LOG_INFO("JIT Register function [%s]: %p", name.c_str(), address);
  }
#endif  // CINN_WITH_DEBUG
  std::lock_guard<std::mutex> lock(mu_);
  auto it = symbols_.find(name);
  if (it != symbols_.end()) {
    CHECK_EQ(it->second, address) << "Duplicate register symbol [" << name << "]";
    return;
  }

  symbols_.insert({name, reinterpret_cast<void *>(address)});
}

void RuntimeSymbols::Clear() {
  std::lock_guard<std::mutex> lock(mu_);
  symbols_.clear();
  scalar_holder_.clear();
}

}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/runtime_symbol_registry.h b/paddle/cinn/backends/llvm/runtime_symbol_registry.h
new file mode 100644
index 0000000000000..91e82cb1ffad9
--- /dev/null
+++ b/paddle/cinn/backends/llvm/runtime_symbol_registry.h
@@ -0,0 +1,113 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/strings/string_view.h>
#include <glog/logging.h>

#include <cstring>
#include <map>
#include <mutex>  // NOLINT
#include <string>
#include <type_traits>
#include <vector>

#include "cinn/common/macros.h"

namespace cinn {
namespace backends {

class RuntimeSymbols {
 public:
  RuntimeSymbols() = default;

  RuntimeSymbols(const RuntimeSymbols &) = delete;

  RuntimeSymbols(RuntimeSymbols &&rhs) {
    symbols_ = std::move(rhs.symbols_);
    scalar_holder_ = std::move(rhs.scalar_holder_);
  }

  /**
   * Register function address.
   * @param name Name of the symbol.
   * @param address Address of the function.
   */
  void RegisterFn(const std::string &name, void *address) { Register(name, address); }

  /**
   * Register scalar.
   * @tparam T Type of the scalar.
   * @param name Name of the symbol.
   * @param val Scalar value.
   */
  template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
  void RegisterVar(const std::string &name, T val) {
    void *data_ptr = nullptr;
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto &data = scalar_holder_[name];
      data.resize(sizeof(T));
      memcpy(data.data(), &val, sizeof(T));
      data_ptr = reinterpret_cast<void *>(data.data());
    }
    Register(name, data_ptr);
  }

  /**
   * Look up a symbol from the registry.
   * @param name Name of the symbol.
   * @return The address if it exists, otherwise nullptr.
   */
  void *Lookup(absl::string_view name) const;

  /**
   * Get all the symbols.
   */
  const std::map<std::string, void *> &All() const { return symbols_; }

  /**
   * Clear all the symbols.
   */
  void Clear();

 private:
  /**
   * Register an external symbol to the registry; the symbols in the registry will
   * finally be registered to the JIT.
   * @param name Name of the symbol in the JIT.
   * @param address The address of the variable in external space.
   */
  void Register(const std::string &name, void *address);

  mutable std::mutex mu_;
  std::map<std::string, void *> symbols_;
  std::map<std::string, std::vector<char>> scalar_holder_;
};

/**
 * Registry for runtime symbols; these symbols will be inserted into the JIT.
 */
class GlobalSymbolRegistry {
 public:
  static RuntimeSymbols &Global();

 private:
  GlobalSymbolRegistry() = default;
  CINN_DISALLOW_COPY_AND_ASSIGN(GlobalSymbolRegistry);
};

}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/llvm/simple_jit.cc b/paddle/cinn/backends/llvm/simple_jit.cc
new file mode 100755
index 0000000000000..77f55e18644cd
--- /dev/null
+++ b/paddle/cinn/backends/llvm/simple_jit.cc
@@ -0,0 +1,133 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/llvm/simple_jit.h"

#include <llvm/AsmParser/Parser.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/Core.h>
#include <llvm/ExecutionEngine/Orc/ExecutionUtils.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>

#include <memory>
#include <string>
#include <utility>

#include "cinn/backends/codegen_cuda_host.h"
#include "cinn/backends/llvm/cinn_runtime_llvm_ir.h"
#include "cinn/backends/llvm/codegen_llvm.h"
#include "cinn/backends/llvm/llvm_util.h"
#include "cinn/backends/llvm/runtime_symbol_registry.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/runtime/intrinsic.h"

namespace cinn {
namespace backends {

void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
  /*
  for (auto &fn : module->functions()) {
    LOG(INFO) << "fn:\n" << DumpToString(fn);
  }
  */
  CHECK(!llvm::verifyModule(*module, &llvm::errs())) << "Transformation resulted in an invalid module\n\nmodule:\n";

  bool debug = false;
  if (optimize) {
    llvm::PassBuilder pass_builder;
    llvm::LoopAnalysisManager loop_analysis_manager(debug);
    llvm::FunctionAnalysisManager function_analysis_manager(debug);
    llvm::CGSCCAnalysisManager cgscc_analysis_manager(debug);
    llvm::ModuleAnalysisManager module_analysis_manager(debug);

    pass_builder.registerModuleAnalyses(module_analysis_manager);
    pass_builder.registerCGSCCAnalyses(cgscc_analysis_manager);
    pass_builder.registerFunctionAnalyses(function_analysis_manager);
    pass_builder.registerLoopAnalyses(loop_analysis_manager);
    pass_builder.crossRegisterProxies(
        loop_analysis_manager, function_analysis_manager, cgscc_analysis_manager, module_analysis_manager);

    llvm::ModulePassManager module_pass_manager =
        pass_builder.buildPerModuleDefaultPipeline(llvm::PassBuilder::OptimizationLevel::O3);
    module_pass_manager.run(*module, module_analysis_manager);
  }

  VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation();
  VLOG(3) << "module target: " << module->getDataLayout().getStringRepresentation();

  llvm::orc::ThreadSafeModule tsm(std::move(module), context_);
  llvm::cantFail(jit_->addIRModule(std::move(tsm)));

  if (debug) {
    std::string buffer;
    llvm::raw_string_ostream os(buffer);
    jit_->getExecutionSession().dump(os);
    os.flush();
    VLOG(3) << "compiled jit:\n" << buffer;
  }
}

SimpleJIT::SimpleJIT() : context_(std::make_unique<llvm::LLVMContext>()) {
  llvm::InitializeAllTargetInfos();
  llvm::InitializeAllTargets();
  llvm::InitializeAllTargetMCs();
  llvm::InitializeAllAsmParsers();
  llvm::InitializeAllAsmPrinters();

  jit_ = llvm::cantFail(llvm::orc::LLJITBuilder().create());
  CHECK(jit_) << "JIT create failed";

  auto proc_symbols_generator = llvm::cantFail(
      llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(jit_->getDataLayout().getGlobalPrefix()));
  jit_->getMainJITDylib().addGenerator(std::move(proc_symbols_generator));

  llvm::orc::MangleAndInterner mangle(jit_->getExecutionSession(), jit_->getDataLayout());

  for (auto &item : GlobalSymbolRegistry::Global().All()) {
    VLOG(2) << "Insert [" << item.first << "] to SimpleJIT";
    llvm::cantFail(jit_->define(llvm::orc::absoluteSymbols(
        {{mangle(item.first), {llvm::pointerToJITTargetAddress(item.second), llvm::JITSymbolFlags::None}}})));
  }
}

template <typename CodeGenT>
void SimpleJIT::Link(ir::Module module, bool optimize) {
  std::string runtime_ir(backends::kRuntimeLlvmIr);
  llvm::SMDiagnostic error;
  auto m = llvm::parseAssemblyString(runtime_ir, error, context());
  m->setDataLayout(jit_->getDataLayout());
  auto b = std::make_unique<llvm::IRBuilder<>>(context());

  auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
  ir_emitter->Compile(module);

  CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

  AddModule(std::move(m), optimize);
}

template void SimpleJIT::Link<CodeGenLLVM>(ir::Module module, bool optimize);
template void SimpleJIT::Link<CodeGenCUDA_Host>(ir::Module module, bool optimize);

}  // namespace backends

}  // namespace cinn
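A minimal end-to-end sketch of the JIT path (illustrative only, not part of the patch; the module `m`, the function name "fn_add", and the lowered `(cinn_pod_value_t *, int)` ABI are assumptions for illustration). Symbols registered beforehand through GlobalSymbolRegistry::Global() are defined into the JIT's main dylib by the constructor above, so external runtime calls resolve at link time:

  // Hypothetical: `m` is an ir::Module that already contains a lowered "fn_add".
  auto jit = cinn::backends::SimpleJIT::Create();
  jit->Link<cinn::backends::CodeGenLLVM>(m, /*optimize=*/true);
  auto addr = jit->Lookup("fn_add");
  auto *fn_add = reinterpret_cast<void (*)(cinn_pod_value_t *, int)>(addr);  // assumed ABI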
diff --git a/paddle/cinn/backends/llvm/simple_jit.h b/paddle/cinn/backends/llvm/simple_jit.h
new file mode 100755
index 0000000000000..ebbae127c3d8e
--- /dev/null
+++ b/paddle/cinn/backends/llvm/simple_jit.h
@@ -0,0 +1,82 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/strings/string_view.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>

#include <memory>
#include <string>
#include <utility>

#include "cinn/backends/llvm/codegen_llvm.h"
#include "cinn/backends/llvm/llvm_util.h"
#include "cinn/backends/llvm/runtime_symbol_registry.h"
#include "cinn/ir/module.h"
#include "cinn/runtime/intrinsic.h"

namespace cinn {
namespace backends {

class SimpleJIT {
 public:
  static std::unique_ptr<SimpleJIT> Create() { return std::unique_ptr<SimpleJIT>(new SimpleJIT); }

  /**
   * Runtime link to a module.
   * @tparam CodeGenT a CodeGenLLVM implementation.
   * @param module a CINN module.
   * @param optimize whether to optimize.
   */
  template <typename CodeGenT>
  void Link(ir::Module module, bool optimize = true);

  void Link(llvm::orc::ThreadSafeModule m, bool optimize = true) { llvm::cantFail(jit_->addIRModule(std::move(m))); }

  llvm::JITTargetAddress Lookup(absl::string_view name) {
    return llvm::cantFail(jit_->lookup(AsStringRef(name))).getAddress();
  }

 private:
  void AddModule(std::unique_ptr<llvm::Module> module, bool optimize);

  llvm::LLVMContext &context() { return *context_.getContext(); }

  SimpleJIT();

  std::unique_ptr<llvm::orc::LLJIT> jit_;
  llvm::orc::ThreadSafeContext context_;
};

}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/modular.cc b/paddle/cinn/backends/modular.cc
new file mode 100644
index 0000000000000..e09c06b0d43ef
--- /dev/null
+++ b/paddle/cinn/backends/modular.cc
@@ -0,0 +1,128 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/modular.h"

#include <limits>

#include "cinn/ir/ir_visitor.h"

namespace cinn {
namespace backends {

class ModularEvaluator : public ir::IRVisitorBase<ModularEntry> {
 public:
  explicit ModularEvaluator(const std::map<Var, ModularEntry>& mod_map) : mod_map_(mod_map) {}

  ModularEntry Eval(const Expr& e) { return ir::IRVisitorBase<ModularEntry>::Visit(&e); }

  ModularEntry Visit(const ir::IntImm* op) {
    if (op->value < std::numeric_limits<int>::max()) {
      return ModularEntry{static_cast<int>(op->value), 0};
    }
    return ModularEntry::everything();
  }

  ModularEntry Visit(const ir::UIntImm* op) {
    if (op->value < std::numeric_limits<int>::max()) {
      return ModularEntry{static_cast<int>(op->value), 0};
    }
    return ModularEntry::everything();
  }

  ModularEntry Visit(const ir::_Var_* op) {
    Var var(&Reference(op));
    auto it = mod_map_.find(var);
    if (it != mod_map_.end()) return it->second;
    return ModularEntry::everything();
  }

  ModularEntry Visit(const ir::Add* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());
    ModularEntry ret;
    ret.coeff = gcd(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base + b.base, ret.coeff);
    return ret;
  }

  ModularEntry Visit(const ir::Sub* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());

    ModularEntry ret;
    ret.coeff = gcd(a.coeff, b.coeff);
    ret.base = BaseSimplify(a.base - b.base, ret.coeff);
    return ret;
  }

  ModularEntry Visit(const ir::Mul* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());

    int pq = a.coeff * b.coeff;
    int pm = a.coeff * b.base;
    int qn = a.base * b.coeff;

    ModularEntry ret;
    ret.coeff = gcd(pq, gcd(pm, qn));
    ret.base = BaseSimplify(a.base * b.base, ret.coeff);
    return ret;
  }

  ModularEntry Visit(const ir::Div* op) {
    auto a = Eval(op->a());
    auto b = Eval(op->b());

    if (b.coeff % b.base == 0) {
      ModularEntry ret;
      ret.coeff = a.coeff / b.base;
      ret.base = 0;
      return ret;
    }

    return ModularEntry::everything();
  }

  static int BaseSimplify(int base, int coeff) {
    if (coeff == 0) return base;
    base = base % coeff;
    if (base < 0) base += coeff;
    return base;
  }

  static int gcd(int a, int b) {
    CHECK_GE(a, 0);
    CHECK_GE(b, 0);
    if (a < b) std::swap(a, b);
    if (b == 0) return a;

    while (a % b != 0) {
      a = a % b;
      std::swap(a, b);
    }
    return b;
  }

 private:
  const std::map<Var, ModularEntry>& mod_map_;
};

ModularEntry ModularEntry::Add(const ModularEntry& a, const ModularEntry& b) {
  ModularEntry ret;
  ret.coeff = ModularEvaluator::gcd(a.coeff, b.coeff);
  ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff);
  return ret;
}

}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/modular.h b/paddle/cinn/backends/modular.h
new file mode 100644
index 0000000000000..a72bc9f922b18
--- /dev/null
+++ b/paddle/cinn/backends/modular.h
@@ -0,0 +1,40 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>

#include "cinn/ir/ir.h"

namespace cinn {
namespace backends {

// borrowed from Halide and TVM.
struct ModularEntry {
  int base;
  int coeff;

  ModularEntry() = default;
  ModularEntry(int base, int coeff) : base(base), coeff(coeff) {}

  static ModularEntry everything() { return ModularEntry{0, 1}; }

  static ModularEntry Add(const ModularEntry& a, const ModularEntry& b);
};

ModularEntry EvalModular(const Expr& e, const std::map<Var, ModularEntry>& mod_map);

}  // namespace backends
}  // namespace cinn
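A worked example of the evaluator's arithmetic: for the index expression `4 * i + 6` with `i` unconstrained, the constant 4 evaluates to {base 4, coeff 0} and the variable to everything() = {base 0, coeff 1}; Mul combines them into {base 0, coeff 4}, and Add then folds in {base 6, coeff 0} to give {base 2, coeff 4}, i.e. the expression is provably congruent to 2 (mod 4). This is the kind of alignment fact the backend can exploit for strided buffer accesses.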
diff --git a/paddle/cinn/backends/nvrtc/CMakeLists.txt b/paddle/cinn/backends/nvrtc/CMakeLists.txt
new file mode 100644
index 0000000000000..a344b65ca93e4
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/CMakeLists.txt
@@ -0,0 +1,8 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS
  header_generator.cc
  nvrtc_util.cc
)

nv_test(test_nvrtc_util SRCS nvrtc_util_test.cc DEPS cinncore)
diff --git a/paddle/cinn/backends/nvrtc/header_generator.cc b/paddle/cinn/backends/nvrtc/header_generator.cc
new file mode 100644
index 0000000000000..85972814bcbc0
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/header_generator.cc
@@ -0,0 +1,44 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/nvrtc/header_generator.h"

#include "glog/logging.h"
#include "jitify.hpp"

namespace cinn {
namespace backends {
namespace nvrtc {

HeaderGeneratorBase& JitSafeHeaderGenerator::GetInstance() {
  static JitSafeHeaderGenerator instance;
  return instance;
}

const size_t JitSafeHeaderGenerator::size() const {
  CHECK_EQ(include_names_.size(), headers_.size()) << "Internal error in size of header files.";
  return include_names_.size();
}

JitSafeHeaderGenerator::JitSafeHeaderGenerator() {
  const auto& headers_map = ::jitify::detail::get_jitsafe_headers_map();
  for (auto& pair : headers_map) {
    include_names_.emplace_back(pair.first.data());
    headers_.emplace_back(pair.second.data());
  }
}

}  // namespace nvrtc
}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/nvrtc/header_generator.h b/paddle/cinn/backends/nvrtc/header_generator.h
new file mode 100644
index 0000000000000..1e6e57665857e
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/header_generator.h
@@ -0,0 +1,47 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstddef>
#include <vector>

namespace cinn {
namespace backends {
class HeaderGeneratorBase {
 public:
  virtual const size_t size() const = 0;
  virtual const std::vector<const char*>& headers() const = 0;
  virtual const std::vector<const char*>& include_names() const = 0;
};

namespace nvrtc {

class JitSafeHeaderGenerator : public HeaderGeneratorBase {
 public:
  static HeaderGeneratorBase& GetInstance();
  const size_t size() const;
  const std::vector<const char*>& headers() const override { return headers_; }
  const std::vector<const char*>& include_names() const override { return include_names_; }

 private:
  JitSafeHeaderGenerator();
  std::vector<const char*> headers_;
  std::vector<const char*> include_names_;
};

}  // namespace nvrtc
}  // namespace backends
}  // namespace cinn
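The generator exists so that JIT-compiled CUDA sources can `#include` standard headers without touching the filesystem. A sketch of the intended wiring (illustrative; the real call site is in nvrtc_util.cc below):

  const auto &gen = cinn::backends::nvrtc::JitSafeHeaderGenerator::GetInstance();
  nvrtcProgram prog;
  // Hand jitify's built-in header bodies to NVRTC so in-memory includes resolve.
  nvrtcCreateProgram(&prog, code.c_str(), nullptr, gen.size(),
                     gen.headers().data(), gen.include_names().data());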
diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util.cc b/paddle/cinn/backends/nvrtc/nvrtc_util.cc
new file mode 100644
index 0000000000000..4598054701129
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/nvrtc_util.cc
@@ -0,0 +1,239 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/backends/nvrtc/nvrtc_util.h"

#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include <fstream>
#include <string>

#include "cinn/backends/cuda_util.h"
#include "cinn/backends/nvrtc/header_generator.h"
#include "cinn/common/common.h"
#include "cinn/runtime/flags.h"
#include "cinn/utils/string.h"

DECLARE_string(cinn_nvcc_cmd_path);
DECLARE_bool(nvrtc_compile_to_cubin);

namespace cinn {
namespace backends {
namespace nvrtc {

std::string Compiler::operator()(const std::string& code, bool include_headers) {
  if (runtime::CanUseNvccCompiler()) {
    return CompileWithNvcc(code);
  }
  return CompileCudaSource(code, include_headers);
}

Compiler::Compiler() {
  if (FLAGS_nvrtc_compile_to_cubin) {
#if CUDA_VERSION >= 11010
    compile_to_cubin_ = true;
#endif
  }
  VLOG(4) << "FLAGS_nvrtc_compile_to_cubin: " << FLAGS_nvrtc_compile_to_cubin
          << ", compile_to_cubin_: " << compile_to_cubin_;
}

bool Compiler::compile_to_cubin() { return compile_to_cubin_; }

std::vector<std::string> Compiler::FindCUDAIncludePaths() {
  const std::string delimiter = "/";
  std::string cuda_include_path;
  const char* cuda_path_env = std::getenv("CUDA_PATH");
  if (cuda_path_env != nullptr) {
    cuda_include_path += cuda_path_env;
    cuda_include_path += delimiter + "include";
    return {cuda_include_path};
  }

#if defined(__linux__)
  struct stat st;
  cuda_include_path = "/usr/local/cuda/include";
  if (stat(cuda_include_path.c_str(), &st) == 0) {
    return {cuda_include_path};
  }
#endif
  LOG(FATAL) << "Cannot find cuda include path. "
             << "CUDA_PATH is not set or CUDA is not installed in the default installation path. "
             << "On platforms other than Linux, CUDA_PATH must be set.";
  return {cuda_include_path};
}

std::vector<std::string> Compiler::FindCINNRuntimeIncludePaths() { return {Context::Global().runtime_include_dir()}; }

std::string Compiler::CompileCudaSource(const std::string& code, bool include_headers) {
  const auto& header_gen = JitSafeHeaderGenerator::GetInstance();
  std::vector<std::string> compile_options;
  std::vector<const char*> param_cstrings{};
  nvrtcProgram prog;
  std::string cc = "30";
  int major, minor;
  cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);

  if (e1 == cudaSuccess && e2 == cudaSuccess) {
    cc = std::to_string(major) + std::to_string(minor);
  } else {
    LOG(WARNING) << "cannot detect compute capability from your device, "
                 << "fall back to compute_30.";
  }
  if (compile_to_cubin_) {
    compile_options.push_back("-arch=sm_" + cc);
  } else {
    compile_options.push_back("-arch=compute_" + cc);
  }
  compile_options.push_back("-std=c++14");
  compile_options.push_back("-default-device");

  if (include_headers) {  // prepare include headers
    auto cuda_headers = FindCUDAIncludePaths();
    auto cinn_headers = FindCINNRuntimeIncludePaths();
    std::vector<std::string> include_paths;
    for (auto& header : cuda_headers) {
      include_paths.push_back("--include-path=" + header);
    }
    for (auto& header : cinn_headers) {
      include_paths.push_back("--include-path=" + header);
    }
    compile_options.insert(std::end(compile_options), include_paths.begin(), include_paths.end());
  }

  for (const auto& option : compile_options) {
    param_cstrings.push_back(option.c_str());
  }
  VLOG(3) << "compile options: " << utils::Join(compile_options, " ");
  NVRTC_CALL(nvrtcCreateProgram(
      &prog, code.c_str(), nullptr, header_gen.size(), header_gen.headers().data(), header_gen.include_names().data()));
  nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());

  {  // get log
    size_t log_size;
    NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
    std::string log;
    log.resize(log_size);
    NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
    CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
  }

  size_t size;
  std::string data;
  if (compile_to_cubin_) {
    NVRTC_CALL(nvrtcGetCUBINSize(prog, &size));
    data.resize(size);
    NVRTC_CALL(nvrtcGetCUBIN(prog, &data[0]));
  } else {
    NVRTC_CALL(nvrtcGetPTXSize(prog, &size));
    data.resize(size);
    NVRTC_CALL(nvrtcGetPTX(prog, &data[0]));
  }

  NVRTC_CALL(nvrtcDestroyProgram(&prog));
  return data;
}

std::string Compiler::CompileWithNvcc(const std::string& cuda_c) {
  // read dir source
  std::string dir = "./source";
  if (access(dir.c_str(), 0) == -1) {
    CHECK(mkdir(dir.c_str(), 7) != -1) << "Fail to mkdir " << dir;
  }

  // get unique prefix name
  prefix_name_ = dir + "/" + common::UniqName("rtc_tmp");

  auto cuda_c_file = prefix_name_ + ".cu";
  std::ofstream ofs(cuda_c_file, std::ios::out);
  CHECK(ofs.is_open()) << "Fail to open file " << cuda_c_file;
  ofs << cuda_c;
  ofs.close();

  CompileToPtx();
  CompileToCubin();

  return prefix_name_ + ".cubin";
}

// std::string Compiler::GetPtx() { return ReadFile(prefix_name_ + ".ptx", std::ios::in); }

void Compiler::CompileToPtx() {
  auto include_dir = common::Context::Global().runtime_include_dir();
  std::string include_dir_str = "";
  for (auto dir : include_dir) {
    if (include_dir_str.empty()) {
      include_dir_str = dir;
    } else {
      include_dir_str += ":" + dir;
    }
  }

  std::string options = std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path +
                        std::string(":$PATH && nvcc -std=c++14 --ptx -O3 -I ") + include_dir_str;
  options += " -arch=" + GetDeviceArch();
  options += " -o " + prefix_name_ + ".ptx";
  options += " " + prefix_name_ + ".cu";

  VLOG(2) << "Nvcc Compile Options : " << options;
  CHECK(system(options.c_str()) == 0) << options;
}

void Compiler::CompileToCubin() {
  std::string options =
      std::string("export PATH=") + FLAGS_cinn_nvcc_cmd_path + std::string(":$PATH && nvcc --cubin -O3");
  options += " -arch=" + GetDeviceArch();
  options += " -o " + prefix_name_ + ".cubin";
  options += " " + prefix_name_ + ".ptx";

  VLOG(2) << "Nvcc Compile Options : " << options;
  CHECK(system(options.c_str()) == 0) << options;
}

std::string Compiler::GetDeviceArch() {
  int major = 0, minor = 0;
  if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0) == cudaSuccess &&
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0) == cudaSuccess) {
    return "sm_" + std::to_string(major) + std::to_string(minor);
  } else {
    LOG(WARNING) << "cannot detect compute capability from your device, "
                 << "fall back to compute_30.";
    return "sm_30";
  }
}

std::string Compiler::ReadFile(const std::string& file_name, std::ios_base::openmode mode) {
  // open cubin file
  std::ifstream ifs(file_name, mode);
  CHECK(ifs.is_open()) << "Fail to open file " << file_name;
  ifs.seekg(0, std::ios::end);
  auto len = ifs.tellg();
  ifs.seekg(0);

  // read cubin file
  std::string file_data(len, ' ');
  ifs.read(&file_data[0], len);
  ifs.close();
  return file_data;
}

}  // namespace nvrtc
}  // namespace backends
}  // namespace cinn
diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util.h b/paddle/cinn/backends/nvrtc/nvrtc_util.h
new file mode 100644
index 0000000000000..b13c24c550a63
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/nvrtc_util.h
@@ -0,0 +1,92 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef CINN_WITH_CUDA
#if defined(__linux__)
#include <sys/stat.h>
#endif
#include <fstream>

#include <string>
#include <vector>

namespace cinn {
namespace backends {
namespace nvrtc {

/**
 * A helper class to call NVRTC: input CUDA device source code, get a PTX string.
 */
class Compiler {
 public:
  Compiler();

  /**
   * Compile the \p code and get PTX string.
   * @param code The CUDA source code.
   * @param include_headers Whether to include the headers of CUDA and CINN runtime modules.
   * @return Compiled PTX code string.
   */
  std::string operator()(const std::string& code, bool include_headers = true);

  /** Compile into cubin or not.
   * @return Compile into cubin or not.
   */
  bool compile_to_cubin();

 private:
  /**
   * Get the directories of CUDA's header files.
   * @return list of header file directories.
   */
  std::vector<std::string> FindCUDAIncludePaths();

  /**
   * Get the directories of CINN runtime's header files.
   * @return list of header file directories.
   */
  std::vector<std::string> FindCINNRuntimeIncludePaths();

  /**
   * Compile CUDA source code and get PTX or CUBIN.
   * @param code source code string.
   * @return PTX or CUBIN string.
   */
  std::string CompileCudaSource(const std::string& code, bool include_headers);

  /**
   * Whether to compile the source code into cubin; only works with CUDA 11.1 or newer.
   */
  bool compile_to_cubin_{false};

  // compile with nvcc
  std::string CompileWithNvcc(const std::string&);

  // compile to ptx
  void CompileToPtx();
  // compile to cubin
  void CompileToCubin();
  std::string GetDeviceArch();

  std::string ReadFile(const std::string&, std::ios_base::openmode);

  std::string prefix_name_{""};
};

}  // namespace nvrtc
}  // namespace backends
}  // namespace cinn

#endif  // CINN_WITH_CUDA
diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util_test.cc b/paddle/cinn/backends/nvrtc/nvrtc_util_test.cc
new file mode 100644
index 0000000000000..9a21934130086
--- /dev/null
+++ b/paddle/cinn/backends/nvrtc/nvrtc_util_test.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "cinn/backends/nvrtc/nvrtc_util.h" + +#include + +namespace cinn { +namespace backends { +namespace nvrtc { + +TEST(Compiler, basic) { + Compiler compiler; + + std::string source_code = R"ROC( +extern "C" __global__ +void saxpy(float a, float *x, float *y, float *out, size_t n) +{ + size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + out[tid] = a * x[tid] + y[tid]; + } +} +)ROC"; + + auto ptx = compiler(source_code); + + LOG(INFO) << "ptx:\n" << ptx; +} + +TEST(Compiler, float16) { + Compiler compiler; + + std::string source_code = R"( +#include +#define CINN_WITH_CUDA +#include "float16.h" +using cinn::common::float16; + +extern "C" __global__ +void cast_fp32_to_fp16_cuda_kernel(const float* input, const int num, float16* out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num) { + out[idx] = float16(input[idx]); + } +} +)"; + + auto ptx = compiler(source_code); + + LOG(INFO) << "ptx:\n" << ptx; +} + +TEST(Compiler, bfloat16) { + Compiler compiler; + + std::string source_code = R"( +#include +#define CINN_WITH_CUDA +#include "bfloat16.h" +using cinn::common::bfloat16; + +extern "C" __global__ +void cast_fp32_to_bf16_cuda_kernel(const float* input, const int num, bfloat16* out) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < num) { + out[idx] = bfloat16(input[idx]); + } +} +)"; + + auto ptx = compiler(source_code); + + LOG(INFO) << "ptx:\n" << ptx; +} + +} // namespace nvrtc +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/outputs.cc b/paddle/cinn/backends/outputs.cc new file mode 100644 index 0000000000000..65d4cc76899fe --- /dev/null +++ b/paddle/cinn/backends/outputs.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/backends/outputs.h" + +namespace cinn { +namespace lang {} // namespace lang + +backends::Outputs backends::Outputs::object(const std::string &name) const { + Outputs updated = *this; + updated.object_name = name; + return updated; +} + +backends::Outputs backends::Outputs::bitcode(const std::string &name) const { + Outputs updated = *this; + updated.bitcode_name = name; + return updated; +} + +backends::Outputs backends::Outputs::c_header(const std::string &name) const { + Outputs updated = *this; + updated.c_header_name = name; + return updated; +} + +backends::Outputs backends::Outputs::c_source(const std::string &name) const { + Outputs updated = *this; + updated.c_source_name = name; + return updated; +} + +backends::Outputs backends::Outputs::cuda_source(const std::string &name) const { + Outputs updated = *this; + updated.cuda_source_name = name; + return updated; +} + +} // namespace cinn diff --git a/paddle/cinn/backends/outputs.h b/paddle/cinn/backends/outputs.h new file mode 100644 index 0000000000000..45c4c9e1418e7 --- /dev/null +++ b/paddle/cinn/backends/outputs.h @@ -0,0 +1,52 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace cinn { +namespace backends { + +/** + * A struct specifying a collection of outputs. + */ +struct Outputs { + //! The name of the emitted object file. Empty if no object file is desired. + std::string object_name; + + //! The name of the emitted llvm bitcode. Empty if no bitcode file is desired. + std::string bitcode_name; + + //! The name of the emitted C header file. + std::string c_header_name; + + //! The name of the emitted C source file. + std::string c_source_name; + + //! The name of the emitted CUDA source file. + std::string cuda_source_name; + + Outputs object(const std::string& name) const; + + Outputs bitcode(const std::string& name) const; + + Outputs c_header(const std::string& name) const; + + Outputs c_source(const std::string& name) const; + + Outputs cuda_source(const std::string& name) const; +}; + +} // namespace backends +} // namespace cinn diff --git a/paddle/cinn/backends/raw_cuda_code_test.cu b/paddle/cinn/backends/raw_cuda_code_test.cu new file mode 100644 index 0000000000000..765ef5bd986bb --- /dev/null +++ b/paddle/cinn/backends/raw_cuda_code_test.cu @@ -0,0 +1,54 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "cinn/backends/cuda_util.h" +#include "cinn/utils/timer.h" + +__global__ void elementwise_add_kernel(const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C) { + if ((blockIdx.x < 1024)) { + { + if ((threadIdx.x < 1024)) { + { + C[((1024 * blockIdx.x) + threadIdx.x)] = + (A[((1024 * blockIdx.x) + threadIdx.x)] + B[((1024 * blockIdx.x) + threadIdx.x)]); + } + } + } + } +} + +TEST(raw_cuda, basic) { + const int M = 1024; + const int N = 1024; + // allocate CUDA buffer + float *Ag, *Bg, *Cg; + const int num_bytes = M * N * sizeof(float); + cudaMalloc(&Ag, num_bytes); + cudaMalloc(&Bg, num_bytes); + cudaMalloc(&Cg, num_bytes); + + cinn::utils::Timer timer; + timer.Start(); + for (int i = 0; i < 1000; i++) { + elementwise_add_kernel<<<1024, 1024>>>(Ag, Bg, Cg); + } + CUDA_CALL(cudaDeviceSynchronize()); + float latency = timer.Stop(); + LOG(INFO) << "latency: " << latency / 1000; +} diff --git a/paddle/cinn/cinn.h b/paddle/cinn/cinn.h new file mode 100644 index 0000000000000..41ce22a7b54ba --- /dev/null +++ b/paddle/cinn/cinn.h @@ -0,0 +1,56 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * This file exposes some internal APIs to global cinn namespace to make usage more friendly. + */ +#pragma once +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/codegen_c_x86.h" +#include "cinn/common/common.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/placeholder.h" +#include "cinn/optim/optimize.h" + +namespace cinn { + +using backends::CodeGenC; +using backends::CodeGenCX86; +using backends::Outputs; +using ir::Module; +using ir::Var; +using lang::Buffer; +using lang::CallExtern; +using lang::CallLowered; +using lang::Compute; +using lang::Lower; +using lang::Placeholder; +using lang::ReduceAll; +using lang::ReduceAny; +using lang::ReduceMax; +using lang::ReduceMin; +using lang::ReduceMul; +using lang::ReduceSum; +using optim::Optimize; +using poly::CreateStages; + +using lang::logic_and; +using lang::logic_or; + +using common::Target; + +} // namespace cinn diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt new file mode 100644 index 0000000000000..f45e2812960a0 --- /dev/null +++ b/paddle/cinn/common/CMakeLists.txt @@ -0,0 +1,36 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + shared.cc + cinn_value.cc + type.cc + target.cc + object.cc + debug_manager.cc + info_registry.cc + graph_utils.cc + context.cc + axis.cc + ir_util.cc + test_helper.cc + # cuda_test_helper.cc + arithmatic.cc + cas.cc + union_find.cc + python_interpreter_guard.cc + ) + + message(STATUS "srcs: ${cinnapi_src}") + +cc_test(test_cinn_value SRCS cinn_value_test.cc DEPS cinncore) +cc_test(test_shared SRCS shared_test.cc DEPS cinncore) +cc_test(test_graph_utils SRCS graph_utils_test.cc DEPS cinncore) +cc_test(test_arithmatic SRCS arithmatic_test.cc DEPS cinncore) +cc_test(test_cas SRCS cas_test.cc DEPS cinncore) +cc_test(test_type SRCS type_test.cc DEPS cinncore) +cc_test(test_axis SRCS axis_test.cc DEPS cinncore) + +cc_test(test_fp16_bf16_host SRCS float16_bfloat16_host_test.cc DEPS gtest glog) +if (WITH_CUDA) +nv_test(test_fp16_bf16_cuda SRCS float16_bfloat16_cuda_test.cu DEPS gtest glog) +endif() diff --git a/paddle/cinn/common/arithmatic.cc b/paddle/cinn/common/arithmatic.cc new file mode 100644 index 0000000000000..8fd8bb6f6ec50 --- /dev/null +++ b/paddle/cinn/common/arithmatic.cc @@ -0,0 +1,310 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/common/arithmatic.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace common {
+
+using utils::GetStreamCnt;
+using utils::Join;
+using utils::Replace;
+using utils::Split;
+using namespace ir;  // NOLINT
+
+#ifdef As
+#undef As
+#endif
+
+std::string ExprToGinacConverter::Repr(const ir::Expr& expr) {
+  auto* load_n = expr.As<Load>();
+  auto* var_n = expr.As<_Var_>();
+  auto* broadcast_n = expr.As<Broadcast>();
+  auto* mod_n = expr.As<Mod>();
+  auto* min_n = expr.As<Min>();
+  auto* max_n = expr.As<Max>();
+  auto* div_n = expr.As<Div>();
+  auto* frac_n = expr.As<FracOp>();
+  if (load_n || broadcast_n || mod_n || min_n || max_n || div_n || frac_n) {
+    std::string repr = GetStreamCnt(expr);
+    Replace(&repr, "[", "lsq_");
+    Replace(&repr, "]", "_rsq");
+    Replace(&repr, "(", "lb_");
+    Replace(&repr, ")", "_rb");
+    Replace(&repr, "+", "_add_");
+    Replace(&repr, "-", "_sub_");
+    Replace(&repr, ":", "_ref_");
+    Replace(&repr, "*", "_mul_");
+    Replace(&repr, "/", "_div_");
+    // remove the spaces
+    auto fields = utils::Split(repr, " ");
+    repr = utils::Join(fields, "_");
+    return repr;
+  } else if (var_n) {
+    return utils::GetStreamCnt(expr);
+  }
+  return "";
+}
+
+void ExprToGinacConverter::RecordExpr(const ir::Expr& expr) { repr_to_expr_[Repr(expr)] = expr; }
+
+GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) {
+  auto* load_n = expr.As<Load>();
+  auto* var_n = expr.As<_Var_>();
+  auto* int_n = expr.As<IntImm>();
+  auto* float_n = expr.As<FloatImm>();
+  auto* add_n = expr.As<Add>();
+  auto* sub_n = expr.As<Sub>();
+  auto* mul_n = expr.As<Mul>();
+  auto* div_n = expr.As<Div>
(); + auto* minus_n = expr.As(); + auto* broadcast_n = expr.As(); + auto* mod_n = expr.As(); + auto* frac_n = expr.As(); + auto* min_n = expr.As(); + auto* max_n = expr.As(); + + bool is_integer_math = expr.type().is_int(); + + bool is_invalid_arith = load_n || var_n || broadcast_n || mod_n || min_n || max_n; + if (is_integer_math) + is_invalid_arith = is_invalid_arith || div_n || frac_n; // GiNac can't deal with integer division. + + if (is_invalid_arith) { + RecordExpr(expr); + std::string repr = Repr(expr); + return CreateGinacSymbol(repr); + } else if (int_n) { + return int_n->value; + } else if (float_n) { + return float_n->value; + } else if (add_n) { + auto a = BuildHelper(add_n->a()); + auto b = BuildHelper(add_n->b()); + return (a + b) * 1; + } else if (sub_n) { + return (BuildHelper(sub_n->a()) - BuildHelper(sub_n->b())); + } else if (mul_n) { + return (BuildHelper(mul_n->a()) * BuildHelper(mul_n->b())); + } else if (div_n) { + return (BuildHelper(div_n->a()) / BuildHelper(div_n->b())); + } else if (frac_n) { + return (BuildHelper(frac_n->a()) / BuildHelper(frac_n->b())); + } else if (minus_n) { + return -BuildHelper(minus_n->v()); + } else { + CINN_NOT_IMPLEMENTED + } +} + +GiNaC::ex ExprToGinacConverter::operator()(Expr expr) { + // TODO(Superjomn) Replace this with common::IsPureMath( + auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) { + return n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As() || // + n->As { + Expr condition; + Expr true_value; + Expr false_value; + + Select(Expr condition, Expr true_value, Expr false_value) + : ExprNode(condition, true_value, false_value); + return Expr(node); + } + + Type type() const override { + CHECK_EQ(true_value.type(), false_value.type()); + return true_value.type(); + } + + void Verify() const override; + + std::vector expr_fields() override { return {&condition, &true_value, &false_value}; } + std::vector expr_fields() const override { return {&condition, &true_value, &false_value}; } + + static const IrNodeTy _node_type_ = IrNodeTy::Select; +}; + +struct LoadStoreAddrMnger { + Expr tensor; // Should be a tensor or a scalar. + //! Tell whether the address is a tensor. + bool is_addr_tensor() const; + //! Tell whether the address is a scalar. + bool is_addr_scalar() const; +}; + +/** + * Load the value from a buffer (as an array). + */ +struct Load : public ExprNode, public LoadStoreAddrMnger { + std::vector indices; + //! The abstract offset. + Expr index() const; + + static Expr Make(Expr tensor, const std::vector& indices); + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + void Verify() const override; + + const std::string& name() const; + + Type type() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::Load; +}; + +/** + * Store a `value` to the buffer at a given `index`. + */ +struct Store : public ExprNode, public LoadStoreAddrMnger { + Expr value; + std::vector indices; + + static Expr Make(Expr tensor, Expr value, const std::vector& indices); + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + void Verify() const override; + + const std::string& name() const; + + Type type() const override; + Expr index() const; + + static const IrNodeTy _node_type_ = IrNodeTy::Store; +}; + +/** + * Allocate a buffer with the given type and size. 
The buffer lives for at most the duration of the body statement, + * within which it is freed. + */ +struct Alloc : public ExprNode { + //! The destination of the allocation, this might be a buffer or a variable. + Expr destination; + //! Dimensions of this buffer (as a multi-dimensional array). + std::vector extents; + // NOTE the condition might be undefined, that means always true. + Expr condition; + // NOTE the body might be undefined, that means no specific logic other than default. + Expr body; + + Alloc() : ExprNode(Type()) {} + + static Expr Make(Expr dest, Type type, const std::vector& extents, Expr condition, Expr body); + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + void Verify() const override; + + int32_t ConstantAllocationSize() const; + static int32_t ConstantAllocationSize(const std::vector& extents); + + static const IrNodeTy _node_type_ = IrNodeTy::Alloc; +}; + +/** + * Free the resources associated with the given buffer. + */ +struct Free : public ExprNode { + Expr destination; + + Free() : ExprNode(Type()) {} + + static Expr Make(Expr dest); + + void Verify() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::Free; +}; + +struct IfThenElse : public ExprNode { + Expr condition; + Expr true_case; + Expr false_case; + + IfThenElse(Expr condition, Expr true_case, Expr false_case); + + static Expr Make(Expr condition, Expr true_case, Expr false_case = Expr()); + + void Verify() const override { + CHECK(condition.defined()); + CHECK(true_case.defined()); + CHECK_EQ(condition.type(), type_of()); + } + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::IfThenElse; +}; + +enum class ForType : int { + Serial = 0, //! Serial execution. + Parallel = 1, //! Parallel execution. + Vectorized = 1 << 1, //! Vector SIMD loop annotation. + Unrolled = 1 << 2, //! Unroll annotation. + GPUThread = 1 << 3, //! GPU Thread. + GPUBlock = 1 << 4, //! GPU Block. + GPULane = 1 << 5, //! GPU Lane. 
+  Default = 1 << 6,
+};
+
+struct VectorizeInfo {
+  VectorizeInfo() = default;
+  VectorizeInfo(int level, int factor) : level(level), factor(factor) {}
+
+  int level{-1};
+  int factor{-1};
+
+  inline void set(int level, int factor) {
+    this->level = level;
+    this->factor = factor;
+  }
+  inline bool valid() const { return level >= 0 && factor > 0; }
+};
+
+struct BindInfo {
+  BindInfo() = default;
+  BindInfo(const ForType& for_type, const int& offset, const DeviceAPI& device)
+      : for_type(for_type), offset(offset), device(device) {}
+
+  ForType for_type{ForType::Default};
+  int offset{-1};
+  DeviceAPI device{DeviceAPI::UNK};
+
+  inline void set(const ForType& for_type, const int& offset, const DeviceAPI& device) {
+    this->for_type = for_type;
+    this->offset = offset;
+    this->device = device;
+  }
+  // offset should be in [0, 2], corresponding to the x, y and z thread/block dimensions
+  inline bool valid() const {
+    return offset >= 0 && offset < 3 && (for_type == ForType::GPUThread || for_type == ForType::GPUBlock);
+  }
+};
+
+struct ForBase {
+  ForType for_type() const { return for_type_; }
+  void set_for_type(ForType x) { for_type_ = x; }
+
+  void set_vectorize_info(const VectorizeInfo& x) {
+    if (x.valid()) set_vectorized();
+    vectorize_info_ = x;
+  }
+  void set_bind_info(const BindInfo& x) {
+    if (x.valid()) set_binded(x.for_type);
+    bind_info_ = x;
+  }
+  const VectorizeInfo& vectorize_info() const { return vectorize_info_; }
+  const BindInfo& bind_info() const { return bind_info_; }
+
+  void reset_vectorize_info() {
+    set_vectorized(false);
+    vectorize_info_.factor = -1;
+    vectorize_info_.level = -1;
+  }
+  void reset_bind_info() {
+    set_binded(bind_info_.for_type, false);
+    bind_info_.offset = -1;
+    bind_info_.device = DeviceAPI::UNK;
+  }
+
+  void set_serial() { for_type_ = ForType::Serial; }
+
+  void set_unrolled(bool x = true) {
+    if (x)
+      set_for_type_flag(ForType::Unrolled);
+    else
+      unset_for_type_flag(ForType::Unrolled);
+  }
+  void set_vectorized(bool x = true) {
+    if (x)
+      set_for_type_flag(ForType::Vectorized);
+    else
+      unset_for_type_flag(ForType::Vectorized);
+  }
+  void set_parallel(bool x = true) {
+    if (x)
+      set_for_type_flag(ForType::Parallel);
+    else
+      unset_for_type_flag(ForType::Parallel);
+  }
+  void set_binded(ForType for_type, bool x = true) {
+    if (x)
+      set_for_type_flag(for_type);
+    else
+      unset_for_type_flag(for_type);
+  }
+
+  inline bool is_serial() const { return for_type_ == ForType::Serial; }
+  inline bool is_default() const { return for_type_ == ForType::Default; }
+  inline bool is_unrolled() const { return tell_for_type_flag(ForType::Unrolled); }
+  inline bool is_vectorized() const { return tell_for_type_flag(ForType::Vectorized); }
+  inline bool is_parallel() const { return tell_for_type_flag(ForType::Parallel); }
+  inline bool is_binded() const {
+    return tell_for_type_flag(ForType::GPUBlock) || tell_for_type_flag(ForType::GPUThread);
+  }
+  inline bool is_gpu_block_binded() const { return tell_for_type_flag(ForType::GPUBlock); }
+  inline bool is_gpu_thread_binded() const { return tell_for_type_flag(ForType::GPUThread); }
+
+ private:
+  inline void set_for_type_flag(ForType type) { *reinterpret_cast<int*>(&for_type_) |= static_cast<int>(type); }
+  inline void unset_for_type_flag(ForType type) { *reinterpret_cast<int*>(&for_type_) &= ~static_cast<int>(type); }
+  inline bool tell_for_type_flag(ForType type) const { return static_cast<int>(for_type_) & static_cast<int>(type); }
+
+  ForType for_type_{ForType::Serial};
+  VectorizeInfo vectorize_info_;
+  BindInfo bind_info_;
+};
+
+/// LLVM loop unroll metadata information
+struct LLVMForLoopMeta {
+  enum UnrollMode { DefaultUnroll, FullyUnroll, NoUnroll };
+
+  UnrollMode unroll_mode{DefaultUnroll};
+  bool vectorization{true};
+};
+
+struct For : public ExprNode<For>, public ForBase {
+  //! The loop variable.
+  Var loop_var;
+  //! The minimum value of the iteration.
+  Expr min;
+  //! The extent of the iteration.
+  Expr extent;
+
+  Expr body;
+
+  DeviceAPI device_api;
+
+  LLVMForLoopMeta metadata;
+
+  static Expr Make(Var loop_var,
+                   Expr min,
+                   Expr extent,
+                   ForType for_type,
+                   DeviceAPI device_api,
+                   Expr body,
+                   VectorizeInfo vector_info = VectorizeInfo(),
+                   BindInfo bind_info = BindInfo());
+
+  void Verify() const override;
+
+  std::vector<Expr*> expr_fields() override;
+  std::vector<const Expr*> expr_fields() const override;
+
+  static const IrNodeTy _node_type_ = IrNodeTy::For;
+};
+
+//! Polyhedral for-loop, whose condition is more complex than the normal `For`'s.
+struct PolyFor : public ExprNode<PolyFor>, public ForBase {
+  //! The iterator variable.
+  Var iterator;
+  //! Initial value of the iterator.
+  Expr init;
+  //! The condition to continue the loop.
+  Expr condition;
+  //! Increase the iterator.
+  Expr inc;
+  //! The forloop body.
+  Expr body;
+
+  DeviceAPI device_api;
+
+  PolyFor() : ExprNode(Type()) {}
+
+  Expr ExtractExtent() const;
+
+  static Expr Make(Var iterator,
+                   Expr init_val,
+                   Expr condition,
+                   Expr inc,
+                   ForType for_type,
+                   DeviceAPI device_api,
+                   Expr body,
+                   VectorizeInfo vector_info = VectorizeInfo(),
+                   BindInfo bind_info = BindInfo());
+
+  void Verify() const override;
+
+  std::vector<Expr*> expr_fields() override;
+  std::vector<const Expr*> expr_fields() const override;
+
+  static const IrNodeTy _node_type_ = IrNodeTy::PolyFor;
+};
+
+//! A linear ramp node.
+struct Ramp : public ExprNode<Ramp> {
+  Expr base, stride;
+  int lanes;
+
+  static Expr Make(Expr base, Expr stride, int lanes);
+
+  void Verify() const override;
+
+  static const IrNodeTy _node_type_ = IrNodeTy::Ramp;
+};
+
+//! A vector with `lanes` elements and all of them are `value`.
+struct Broadcast : public ExprNode { + Expr value; + int lanes; + + static Expr Make(Expr value, int lanes); + + Type type() const override; + + void Verify() const override; + + std::vector expr_fields() override { return {&value}; } + std::vector expr_fields() const override { return {&value}; } + + static const IrNodeTy _node_type_ = IrNodeTy::Broadcast; +}; + +struct FracOp : public BinaryOpNode { + FracOp() { operands().resize(2); } + + static Expr Make(Expr n, Expr d); + + bool is_constant() const { return a().is_constant() && b().is_constant(); } + + double get_constant() const { + CHECK(is_constant()); + CHECK_NE(b().get_constant(), 0.f); + return a().get_constant() / b().get_constant(); + } + + void Verify() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::FracOp; + + using ExprNode::operands; +}; + +struct Product : public ExprNode { + static Expr Make(const std::vector& vs); + + using ExprNode::operand; + + Type type() const override { return operands().front().type(); } + + void Verify() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::Product; +}; + +struct Sum : public ExprNode { + static Expr Make(const std::vector& vs); + + using ExprNode::operand; + + Type type() const override { return operands().front().type(); } + + void Verify() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::Sum; +}; + +struct Block : public ExprNode { + std::vector stmts; + + Block() : ExprNode(Type()) {} + + static Expr Make(const std::vector& stmts); + + void Verify() const override; + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::Block; +}; + +// ScheduleBlock is the unit of schedule IR which represents tensor's computation +struct ScheduleBlock : public ExprNode { + std::vector iter_vars; + // BufferRange(s) which is read in this schedule block, it is used to + // analyze, not a real computation expression. Must be AST DFS order. + std::vector read_buffers; + // BufferRange(s) which is written in this schedule block, it is used to + // analyze, not a real computation expression. Must be AST DFS order. + std::vector write_buffers; + // Additional attributes about this schedulable block, + // which take some auxiliary hints for future transformations. + std::map attrs; + std::string name; + Expr body; + + static Expr Make(const std::vector& iter_vars, + const std::vector& read_buffers, + const std::vector& write_buffers, + const std::string& name, + Expr body); + + void Verify() const override; + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::ScheduleBlock; +}; + +// ScheduleBlockRealize is used to execute ScheduleBlock with the binding iter_values +struct ScheduleBlockRealize : public ExprNode { + // values of the iter_vars + std::vector iter_values; + Expr schedule_block; + + static Expr Make(const std::vector& iter_values, const Expr& schedule_block); + + void Verify() const override; + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::ScheduleBlockRealize; +}; + +/** + * Content of a module. 
+ */
+struct _Module_ : public ExprNode<_Module_> {
+  std::string name;
+  Target target;
+  std::vector<Expr> buffers;
+  std::vector<Expr> functions;
+  std::vector<Expr> submodules;
+
+  static ir::Module Make(const std::string& name, Target target);
+
+  void Verify() const override {}
+
+  static const IrNodeTy _node_type_ = IrNodeTy::_Module_;
+};
+
+/**
+ * \brief PrimitiveNode holds the concept of Primitive in CINN.
+ * A Primitive is a basic Call to some Expr function; it is introduced to create several levels of coarse-grained IR
+ * nodes for better IR optimization and hardware adaptation.
+ */
+struct PrimitiveNode : public ExprNode<PrimitiveNode> {
+  std::string name;
+  //! the inputs of the PrimitiveNode, the vector<vector<Expr>> can hold variadic arguments.
+  std::vector<std::vector<Expr>> arguments;
+  //! the attribute of this PrimitiveNode.
+  std::map<std::string, attr_t> attrs;
+
+  static Expr Make(const std::string& name, const std::map<std::string, attr_t>& attrs);
+
+  void Verify() const override;
+
+  static const IrNodeTy _node_type_ = IrNodeTy::PrimitiveNode;
+};
+
+// possible keys of attributes in ir nodes, which are listed in the following namespace
+namespace attr {
+
+// max permitted steps for auto_unroll, used in unroll_loop pass
+constexpr const char* auto_unroll_max_step = "auto_unroll_max_step";
+// record the extra loop built during ComputeAt, used to calculate the size of the temp buffer in post-processing
+constexpr const char* compute_at_extra_var = "compute_at_extra_var";
+// record the extra loop built during ReverseComputeAt, used to calculate the size of the temp buffer in post-processing
+constexpr const char* reverse_compute_at_extra_var = "reverse_compute_at_extra_var";
+// record the cooperative process info, used in post schedule rule (CooperativeProcess)
+constexpr const char* cooperative_process = "cooperative_process";
+
+}  // namespace attr
+
+}  // namespace ir
+
+// Expose the following to cinn namespace for easier usage.
+// @{
+using ir::Expr;
+using ir::Var;
+// @}
+
+}  // namespace cinn
diff --git a/paddle/cinn/ir/ir_base.cc b/paddle/cinn/ir/ir_base.cc
new file mode 100644
index 0000000000000..19c8004fd2bf4
--- /dev/null
+++ b/paddle/cinn/ir/ir_base.cc
@@ -0,0 +1,231 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/ir/ir_base.h"
+
+#include "cinn/common/cinn_value.h"
+#include "cinn/common/common.h"
+#include "cinn/ir/buffer.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/ir/module.h"
+#include "cinn/ir/tensor.h"
+
+namespace cinn {
+namespace ir {
+
+using cinn::common::bfloat16;
+using cinn::common::float16;
+
+//! Implementations for Ir Expr Nodes.
+// @{ +#define __m(t__) \ + template <> \ + void ExprNode::Accept(cinn::ir::IRVisitor *v) const { \ + v->Visit(const_self()); \ + } +#undef __m +// @} + +std::ostream &operator<<(std::ostream &os, IrNodeTy type) { + switch (type) { +#define __m(t__) \ + case IrNodeTy::t__: \ + os << ""; \ + break; + + NODETY_FORALL(__m) +#undef __m + + default: + LOG(FATAL) << "unknown IrNodeTy found"; + } + + return os; +} + +Expr Zero(const Type &type) { + if (type.is_bfloat16()) return Expr(bfloat16(0.f)); + if (type.is_float16()) return Expr(float16(0.f)); + if (type.is_float(32)) return Expr(0.f); + if (type.is_float(64)) return Expr(double(0.)); // NOLINT + + if (type.is_bool()) return Expr(false); + + if (type.is_int(8)) return Expr(int8_t(0)); + if (type.is_int(16)) return Expr(int16_t(0)); + if (type.is_int(32)) return Expr(int32_t(0)); + if (type.is_int(64)) return Expr(int64_t(0)); + + if (type.is_uint(8)) return Expr(uint8_t(0)); + if (type.is_uint(16)) return Expr(uint16_t(0)); + if (type.is_uint(32)) return Expr(uint32_t(0)); + if (type.is_uint(64)) return Expr(uint64_t(0)); + CINN_NOT_IMPLEMENTED + return Expr(); +} + +Expr One(const Type &type) { + if (type.is_bfloat16()) return Expr(bfloat16(1.f)); + if (type.is_float16()) return Expr(float16(1.f)); + if (type.is_float(32)) return Expr(1.f); + if (type.is_float(64)) return Expr(double(1.)); // NOLINT + + if (type.is_bool()) return Expr(true); + + if (type.is_int(8)) return Expr(int8_t(1)); + if (type.is_int(16)) return Expr(int16_t(1)); + if (type.is_int(32)) return Expr(int32_t(1)); + if (type.is_int(64)) return Expr(int64_t(1)); + + if (type.is_uint(8)) return Expr(uint8_t(1)); + if (type.is_uint(16)) return Expr(uint16_t(1)); + if (type.is_uint(32)) return Expr(uint32_t(1)); + if (type.is_uint(64)) return Expr(uint64_t(1)); + CINN_NOT_IMPLEMENTED + return Expr(); +} + +Expr::Expr(const Var &var) { *static_cast(this) = *static_cast(&var); } +bool Expr::as_bool() const { + CHECK(type().is_uint(1)); + return As()->value; +} + +int8_t Expr::as_int8() const { + CHECK(type().is_int(8)); + return As()->value; +} +int16_t Expr::as_int16() const { + CHECK(type().is_int(16)); + return As()->value; +} +int32_t Expr::as_int32() const { + CHECK(type().is_int(32)); + return As()->value; +} +int64_t Expr::as_int64() const { + CHECK(type().is_int(64)); + return As()->value; +} + +uint8_t Expr::as_uint8() const { + CHECK(type().is_uint(8)); + return As()->value; +} +uint16_t Expr::as_uint16() const { + CHECK(type().is_uint(16)); + return As()->value; +} +uint32_t Expr::as_uint32() const { + CHECK(type().is_uint(32)); + return As()->value; +} +uint64_t Expr::as_uint64() const { + CHECK(type().is_uint(64)); + return As()->value; +} + +bfloat16 Expr::as_bfloat16() const { + CHECK(type().is_bfloat16()); + return bfloat16(As()->value); +} +float16 Expr::as_float16() const { + CHECK(type().is_float16()); + return float16(As()->value); +} +float Expr::as_float() const { + CHECK(type().is_float(32)); + return As()->value; +} +double Expr::as_double() const { + CHECK(type().is_float(64)); + return As()->value; +} + +Expr &Expr::operator=(const Expr &other) { + *static_cast(this) = *static_cast(&other); + return *this; +} + +Expr::operator Var() { + auto *x = As(); + CHECK(x); + return ir::Var(x); +} + +bool Expr::is_constant() const { return As() || As() || As(); } + +double Expr::get_constant() const { + CHECK(is_constant()) << *this << " is not constant! 
Please check.";
+  auto *vi = As<IntImm>();
+  auto *vf = As<FloatImm>();
+  if (vi) return vi->value;
+  return vf->value;
+}
+
+bool Expr::is_var() const { return As<_Var_>(); }
+
+_Buffer_ *Expr::as_buffer() { return As<_Buffer_>(); }
+const _Buffer_ *Expr::as_buffer() const { return As<_Buffer_>(); }
+Buffer Expr::as_buffer_ref() const { return Buffer(&Reference(as_buffer())); }
+
+_LoweredFunc_ *Expr::as_lowered_func() { return As<_LoweredFunc_>(); }
+const _LoweredFunc_ *Expr::as_lowered_func() const { return As<_LoweredFunc_>(); }
+
+_Module_ *Expr::as_module() { return As<_Module_>(); }
+const _Module_ *Expr::as_module() const { return As<_Module_>(); }
+ir::Module Expr::as_module_ref() const {
+  auto *module = as_module();
+  CHECK(module);  // Need check here?
+  // TODO(Superjomn) remove the Reference here.
+  return ir::Module(&Reference(module));
+}
+
+LoweredFunc Expr::as_lowered_func_ref() const {
+  auto *function = as_lowered_func();
+  CHECK(function);
+  return LoweredFunc(&Reference(function));
+}
+
+_Tensor_ *Expr::as_tensor() { return As<_Tensor_>(); }
+const _Tensor_ *Expr::as_tensor() const { return As<_Tensor_>(); }
+ir::Tensor Expr::as_tensor_ref() const { return ir::Tensor(&Reference(as_tensor())); }
+
+_Var_ *Expr::as_var() { return As<_Var_>(); }
+const _Var_ *Expr::as_var() const { return As<_Var_>(); }
+Var Expr::as_var_ref() const { return Var(&Reference(as_var())); }
+
+bool Expr::is_cmp() const {
+  switch (node_type()) {
+    case ir::IrNodeTy::LE:
+    case ir::IrNodeTy::LT:
+    case ir::IrNodeTy::EQ:
+    case ir::IrNodeTy::NE:
+    case ir::IrNodeTy::GT:
+    case ir::IrNodeTy::GE:
+      return true;
+    default:
+      return false;
+  }
+}
+
+const Expr &IrNode::operand(int i) {
+  CHECK_LT(i, operands.size());
+  return operands[i];
+}
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h
new file mode 100644
index 0000000000000..b1baf1d59fdea
--- /dev/null
+++ b/paddle/cinn/ir/ir_base.h
@@ -0,0 +1,500 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
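Before the header itself, a small illustrative sketch (not part of the patch) of how the pieces implemented in ir_base.cc above compose: Zero()/One() build typed immediate nodes, and the as_xxx accessors CHECK the stored type before unwrapping it. It uses glog's CHECK macros, as the surrounding code does.

    #include "cinn/ir/ir_base.h"

    void TypedImmediatesDemo() {
      using cinn::ir::Expr;

      Expr zero = cinn::ir::Zero(cinn::common::Int(32));  // an IntImm holding 0
      CHECK(zero.is_constant());
      CHECK_EQ(zero.as_int32(), 0);  // OK: the stored type is int32
      // zero.as_float() would CHECK-fail: the stored type is int, not float

      Expr one = cinn::ir::One(cinn::common::Float(32));  // a FloatImm holding 1.0f
      CHECK_EQ(one.as_float(), 1.0f);
      CHECK_EQ(one.get_constant(), 1.0);  // type-erased accessor for IntImm/FloatImm
    }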
+ +#pragma once + +#include + +#include +#include +#include +#include + +#include "cinn/common/common.h" +#include "cinn/common/object.h" +#include "cinn/common/shared.h" +#include "cinn/common/type.h" + +namespace cinn { + +namespace ir { +using common::BFloat16; +using common::Float; +using common::Float16; +using common::Int; +using common::Type; +using common::type_of; + +class Module; +class IRVisitor; +class _Buffer_; +class Buffer; +class _Module_; +class _LoweredFunc_; +class LoweredFunc; +class _Tensor_; +class Tensor; +class _Var_; +class Var; +class _BufferRange_; +class BufferRange; +class ScheduleBlock; +class ScheduleBlockRealize; + +// clang-format off +#define NODETY_PRIMITIVE_TYPE_FOR_EACH(macro__) \ + macro__(IntImm) \ + macro__(UIntImm) \ + macro__(FloatImm) \ + macro__(StringImm) \ + +#define NODETY_BINARY_OP_FOR_EACH(macro__) \ + macro__(Add) \ + macro__(Sub) \ + macro__(Mul) \ + macro__(Div) \ + macro__(Mod) \ + macro__(EQ) \ + macro__(NE) \ + macro__(LT) \ + macro__(LE) \ + macro__(GT) \ + macro__(GE) \ + macro__(And) \ + macro__(Or) \ + macro__(Min) \ + macro__(Max) \ + +#define NODETY_UNARY_OP_FOR_EACH(macro__) \ + macro__(Minus) \ + macro__(Not) \ + +#define NODETY_OP_FOR_EACH(macro__) NODETY_BINARY_OP_FOR_EACH(macro__) NODETY_UNARY_OP_FOR_EACH(macro__) + +#define NODETY_CONTROL_OP_FOR_EACH(macro__) \ + macro__(Cast) \ + macro__(For) \ + macro__(PolyFor) \ + macro__(Select) \ + macro__(IfThenElse) \ + macro__(Block) \ + macro__(Call) \ + macro__(_Var_) \ + macro__(Load) \ + macro__(Store) \ + macro__(Alloc) \ + macro__(Free) \ + macro__(_Buffer_) \ + macro__(_Tensor_) \ + macro__(_LoweredFunc_) \ + macro__(_Module_) \ + macro__(Let) \ + macro__(Reduce) \ + macro__(Ramp) \ + macro__(Broadcast) \ + macro__(FracOp) \ + macro__(Product) \ + macro__(Sum) \ + macro__(PrimitiveNode) \ + macro__(IntrinsicOp) \ + macro__(_BufferRange_) \ + macro__(ScheduleBlock) \ + macro__(ScheduleBlockRealize) \ + + +#define NODETY_FORALL(__m) \ + NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ + NODETY_OP_FOR_EACH(__m) \ + NODETY_CONTROL_OP_FOR_EACH(__m) +// clang-format on + +//! Define IrNodeTy +// @{ +#define __m(x__) x__, +enum class IrNodeTy { kUnk = -1, NODETY_FORALL(__m) }; +#undef __m +// @} + +//! String representations for IrNodeTy. +// @{ +#define __m(x__) #x__, +const std::vector kIrNodeTyReprs({NODETY_FORALL(__m) "None"}); +#undef __m +// @} + +std::ostream& operator<<(std::ostream& os, IrNodeTy type); + +struct Expr; + +/** + * The base of all the nodes in the IR. + */ +class IrNode : public common::Object { + public: + //! The operands of this operator. + std::vector operands; + + IrNode() = default; + explicit IrNode(Type t) : type_(t) {} + virtual ~IrNode() = default; + + virtual IrNodeTy node_type() const { return IrNodeTy::kUnk; } + virtual Type type() const { return type_; } + void set_type(Type type) { type_ = type; } + + //! Get i-th operand + const Expr& operand(int i); + + //! Gather all the expression fields in this node for easier visit and mutate. + virtual std::vector expr_fields() { return {}; } + virtual std::vector expr_fields() const { return {}; } + + const char* type_info() const override { return __type_info__; } + + //! Verify the current IR node's correctness. + virtual void Verify() const { CINN_NOT_IMPLEMENTED } + + protected: + static constexpr char* __type_info__ = "IRNode"; + Type type_; +}; + +/** + * A handle to store any IRNode. 
+ */
+class IrNodeRef : public common::Shared<IrNode> {
+ public:
+  IrNodeRef() = default;
+  IrNodeRef(const IrNodeRef& other) : Shared(other.p_) {}
+  explicit IrNodeRef(IrNode* x) : Shared(x) {}
+
+  virtual IrNodeTy node_type() const { return operator->()->node_type(); }
+
+  template <typename T>
+  const T* As() const {
+    static_assert(std::is_base_of<IrNode, T>::value);
+    CHECK(get()) << "IrNodeRef holds null";
+    if (node_type() == T::_node_type_) return static_cast<const T*>(get());
+    return nullptr;
+  }
+  template <typename T>
+  T* As() {
+    if (node_type() == T::_node_type_) return static_cast<T*>(get());
+    return nullptr;
+  }
+
+  void operator=(const IrNodeRef& other) {
+    *static_cast<Shared<IrNode>*>(this) = *static_cast<const Shared<IrNode>*>(&other);
+  }
+
+  IrNode* ptr() { return get(); }
+  IrNode* ptr() const { return get(); }
+};
+
+template <typename T>
+struct ExprNode : public IrNode {
+  ExprNode() : IrNode(Type()) {}
+  explicit ExprNode(Type t) : IrNode(t) { set_type(t); }
+  explicit ExprNode(int num_operands) { operands().resize(num_operands); }
+
+  T* self() { return static_cast<T*>(this); }
+  const T* const_self() const { return dynamic_cast<const T*>(this); }
+
+  const std::vector<Expr>& operands() const { return IrNode::operands; }
+  std::vector<Expr>& operands() { return IrNode::operands; }
+
+  Expr& operand(int i) {
+    CHECK_LT(i, operands().size());
+    return operands()[i];
+  }
+  const Expr& operand(int i) const {
+    CHECK_LT(i, operands().size());
+    return operands()[i];
+  }
+
+  virtual Expr Copy() const;
+
+  IrNodeTy node_type() const override { return T::_node_type_; }
+};
+
+struct IntImm : public ExprNode<IntImm> {
+  int64_t value;
+
+  IntImm(Type t, int64_t v) : ExprNode(t), value(v) { Verify(); }
+
+  void Verify() const override {
+    CHECK(type().is_int());
+    CHECK(type().is_scalar());
+    CHECK(type().bits() == 8 || type().bits() == 16 || type().bits() == 32 || type().bits() == 64);
+  }
+
+  static const IrNodeTy _node_type_ = IrNodeTy::IntImm;
+};
+
+struct UIntImm : public ExprNode<UIntImm> {
+  uint64_t value;
+
+  UIntImm(Type t, uint64_t v) : ExprNode(t), value(v) { Verify(); }
+
+  void Verify() const override {
+    CHECK(type().is_uint());
+    CHECK(type().is_scalar());
+    CHECK(type().bits() == 1 /*bool*/ || type().bits() == 8 || type().bits() == 16 || type().bits() == 32 ||
+          type().bits() == 64);
+  }
+
+  static const IrNodeTy _node_type_ = IrNodeTy::UIntImm;
+};
+
+struct FloatImm : public ExprNode<FloatImm> {
+  double value;
+
+  FloatImm(Type t, double v) : ExprNode(t), value(v) { Verify(); }
+
+  void Verify() const override {
+    CHECK(type().is_float());
+    CHECK(type().is_scalar());
+  }
+
+  static const IrNodeTy _node_type_ = IrNodeTy::FloatImm;
+};
+
+struct StringImm : public ExprNode<StringImm> {
+  std::string value;
+
+  explicit StringImm(const std::string& value) : value(value) { Verify(); }
+
+  void Verify() const override {}
+
+  static const IrNodeTy _node_type_ = IrNodeTy::StringImm;
+};
+
+class Var;
+/**
+ * An expression that represents some value or the result of some operations.
+ */
+struct Expr : public IrNodeRef {
+ public:
+  Expr() = default;
+  Expr(const Expr& other) : IrNodeRef(other.ptr()) {}
+  Expr(IrNode* p) : IrNodeRef(p) {}  // NOLINT
+  explicit Expr(const Var& var);
+
+  //! Helper function to construct numeric constants of various types.
+ // @{ + explicit Expr(bool x) : IrNodeRef(new UIntImm(UInt(1), x)) {} + + explicit Expr(int8_t x) : IrNodeRef(new IntImm(Int(8), x)) {} + explicit Expr(int16_t x) : IrNodeRef(new IntImm(Int(16), x)) {} + explicit Expr(int32_t x) : IrNodeRef(new IntImm(Int(32), x)) {} + explicit Expr(int64_t x) : IrNodeRef(new IntImm(Int(64), x)) {} + + explicit Expr(uint8_t x) : IrNodeRef(new UIntImm(UInt(8), x)) {} + explicit Expr(uint16_t x) : IrNodeRef(new UIntImm(UInt(16), x)) {} + explicit Expr(uint32_t x) : IrNodeRef(new UIntImm(UInt(32), x)) {} + explicit Expr(uint64_t x) : IrNodeRef(new UIntImm(UInt(64), x)) {} + + explicit Expr(cinn::common::bfloat16 x) : IrNodeRef(new FloatImm(BFloat16(), x)) {} + explicit Expr(cinn::common::float16 x) : IrNodeRef(new FloatImm(Float16(), x)) {} + explicit Expr(float x) : IrNodeRef(new FloatImm(Float(32), x)) {} + explicit Expr(double x) : IrNodeRef(new FloatImm(Float(64), x)) {} + + explicit Expr(const std::string& x) : IrNodeRef(new StringImm(x)) {} + // @} + + Expr& operator=(const Expr& other); + + // primitive types + // @{ + bool as_bool() const; + + int8_t as_int8() const; + int16_t as_int16() const; + int32_t as_int32() const; + int64_t as_int64() const; + + uint8_t as_uint8() const; + uint16_t as_uint16() const; + uint32_t as_uint32() const; + uint64_t as_uint64() const; + + cinn::common::bfloat16 as_bfloat16() const; + cinn::common::float16 as_float16() const; + float as_float() const; + double as_double() const; + // @} + + _Var_* as_var(); + const _Var_* as_var() const; + Var as_var_ref() const; + + // @{ Other nodes caster. + _Buffer_* as_buffer(); + const _Buffer_* as_buffer() const; + Buffer as_buffer_ref() const; + + _LoweredFunc_* as_lowered_func(); + const _LoweredFunc_* as_lowered_func() const; + LoweredFunc as_lowered_func_ref() const; + + _Module_* as_module(); + const _Module_* as_module() const; + ir::Module as_module_ref() const; + + _Tensor_* as_tensor(); + const _Tensor_* as_tensor() const; + ir::Tensor as_tensor_ref() const; + // @} + + bool is_constant() const; + double get_constant() const; + + //! Tell if this is a compare op. 
+ bool is_cmp() const; + + bool is_var() const; + + operator Var(); + + Type type() const { return p_->type(); } +}; + +template +struct UnaryOpNode : public ExprNode { + UnaryOpNode() { operands().resize(1); } + UnaryOpNode(Type type, Expr v) : ExprNode(type) { + CHECK(v.defined()); + operands().resize(1); + this->v() = v; + } + + Type type() const override { + CHECK(v().defined()); + return v().type(); + } + + Expr& v() { return operands().front(); } + const Expr& v() const { return operands().front(); } + + std::vector expr_fields() override { return {&v()}; } + std::vector expr_fields() const override { return {&v()}; } + + using ExprNode::operands; +}; + +template +struct BinaryOpNode : public ExprNode { + BinaryOpNode() { operands().resize(2); } + BinaryOpNode(Type type, Expr a, Expr b) : ExprNode(type) { + CHECK(type.valid()); + CHECK(a.defined()); + CHECK(b.defined()); + operands().resize(2); + this->a() = a; + this->b() = b; + // CHECK_EQ(a.type(), b.type()) << "the type of two argument not match"; + } + + Expr& a() { return ExprNode::operand(0); } + Expr& b() { return ExprNode::operand(1); } + const Expr& a() const { return ExprNode::operand(0); } + const Expr& b() const { return ExprNode::operand(1); } + + Type type() const override { return a().type(); } + + std::vector expr_fields() override { return {&a(), &b()}; } + std::vector expr_fields() const override { return {&a(), &b()}; } + + using ExprNode::operands; +}; + +//! Zero in CINN type system. +Expr Zero(const Type& type); +Expr One(const Type& type); + +#define DEVICE_API_FOR_ALL(__) \ + __(UNK) \ + __(Host) \ + __(GPU) \ + __(CUDA) \ + __(OpenCL) + +#define __decl__(x) x, +enum class DeviceAPI { DEVICE_API_FOR_ALL(__decl__) }; +#undef __decl__ + +static std::ostream& operator<<(std::ostream& os, DeviceAPI x) { + switch (x) { +#define __decl__(x) \ + case DeviceAPI::x: \ + os << #x; \ + break; + + DEVICE_API_FOR_ALL(__decl__) +#undef __decl__ + + default: + break; + } + return os; +} + +#define MEMORY_TYPE_FOR_ALL(__) \ + __(Auto, "Auto") \ + __(Heap, "Heap") \ + __(Stack, "Stack") \ + __(GPUShared, "GPUShared") \ + __(GPULocal, "GPULocal") \ +/** \ + * An enum describing different address spaces to be used with Func::store_in. \ + */ +enum class MemoryType { +#define __(token__, token_repr__) token__, + MEMORY_TYPE_FOR_ALL(__) +#undef __ +}; + +static std::ostream& operator<<(std::ostream& os, MemoryType t) { + switch (t) { +#define __(token__, token_repr__) \ + case MemoryType::token__: \ + os << token_repr__; \ + break; + + MEMORY_TYPE_FOR_ALL(__) + + default: + LOG(FATAL) << "Not supported memory type"; +#undef __ + } + return os; +} + +template +Expr ExprNode::Copy() const { + LOG(FATAL) << "Not Implemented"; + return Expr(); +} + +} // namespace ir +} // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::ir::Expr& x) { return reinterpret_cast(x.get()); } +}; + +} // namespace std diff --git a/paddle/cinn/ir/ir_compare.cc b/paddle/cinn/ir/ir_compare.cc new file mode 100644 index 0000000000000..16a0672d51fea --- /dev/null +++ b/paddle/cinn/ir/ir_compare.cc @@ -0,0 +1,319 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/ir/ir_compare.h"
+
+#include <regex>
+
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/ir_printer.h"
+
+namespace cinn {
+namespace ir {
+
+bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) {
+  if (lhs.get() == rhs.get()) {  // the same object, including both are null
+    return true;
+  }
+
+  if (!lhs.defined() || !rhs.defined()) {  // someone invalid
+    VLOG(5) << "Not equal on Expr, someone not defined";
+    return false;
+  }
+  bool equal = lhs->node_type() == rhs->node_type();
+  equal = equal && IRVisitorBase<bool, const Expr*>::Visit(&lhs, &rhs);
+
+  if (!equal) {
+    VLOG(5) << "Not equal on Expr, lhs:[type:" << kIrNodeTyReprs[static_cast<int>(lhs->node_type())] << "]\n"
+            << lhs << ", \nrhs[type:" << kIrNodeTyReprs[static_cast<int>(rhs->node_type())] << "]\n"
+            << rhs;
+  }
+  return equal;
+}
+
+bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs, bool allow_name_suffix_diff) {
+  // if allow_name_suffix_diff=true, then just compare the name prefix before the "_[0-9]+"
+  auto common_len = 0;
+  for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) {
+    if (lhs[common_len] != rhs[common_len]) break;
+  }
+
+  auto is_endswith_index = [&common_len](const std::string& name) {
+    const std::regex txt_regex("_\\d+");
+    return common_len == name.size() || std::regex_match(name.substr(common_len), txt_regex);
+  };
+
+  bool equal = false;
+  if (common_len == lhs.size() && common_len == rhs.size()) {
+    equal = true;
+  } else {
+    equal = false;
+    if (allow_name_suffix_diff) {
+      equal = is_endswith_index(lhs) && is_endswith_index(rhs);
+    }
+  }
+
+  if (!equal) {
+    VLOG(5) << "Not equal on name, lhs=" << lhs << ", rhs=" << rhs;
+  }
+
+  return equal;
+}
+
+bool IrEqualVisitor::Compare(const std::map<std::string, attr_t>& lhs, const std::map<std::string, attr_t>& rhs) {
+  if (lhs.size() != rhs.size()) {
+    VLOG(6) << "Not equal on attrs, lhs size=" << lhs.size() << ", rhs size=" << rhs.size();
+    return false;
+  }
+  for (auto&& kv : lhs) {
+    auto opposite = rhs.find(kv.first);
+    if (opposite == rhs.end() || kv.second != opposite->second) {
+      VLOG(6) << "Not equal at attr key=" << kv.first;
+      return false;
+    }
+  }
+  return true;
+}
+
+template <typename T>
+bool IrEqualVisitor::Compare(const std::vector<T>& lhs, const std::vector<T>& rhs) {
+  if (lhs.size() != rhs.size()) {
+    VLOG(6) << "Not equal on repeated fields, lhs size=" << lhs.size() << ", rhs size=" << rhs.size();
+    return false;
+  }
+  for (auto i = 0; i < lhs.size(); ++i) {
+    if (!Compare(lhs.at(i), rhs.at(i))) {
+      VLOG(6) << "Not equal on repeated fields at index=" << i;
+      return false;
+    }
+  }
+  return true;
+}
+
+#define PRIMITIVE_TYPE_IMPL(op__)                                  \
+  bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) { \
+    auto* rhs = other->As<op__>();                                 \
+    return lhs->value == rhs->value;                               \
+  }
+
+#define UNARY_OP_IMPL(op__)                                        \
+  bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) { \
+    auto* rhs = other->As<op__>();                                 \
+    return Compare(lhs->v(), rhs->v());                            \
+  }
+
+#define BINARY_OP_IMPL(op__)                                           \
+  bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) {     \
+    auto* rhs = other->As<op__>();                                     \
+    return Compare(lhs->a(), rhs->a()) && Compare(lhs->b(), rhs->b()); \
+  }
+
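These three IMPL macros are instantiated by the NODETY_*_FOR_EACH X-macro lists right below, stamping out one Visit overload per node type. For instance, PRIMITIVE_TYPE_IMPL(IntImm) expands to roughly:

    bool IrEqualVisitor::Visit(const IntImm* lhs, const Expr* other) {
      auto* rhs = other->As<IntImm>();
      // rhs is non-null here: Compare() only dispatches to Visit() after it
      // has verified that both sides share the same node_type().
      return lhs->value == rhs->value;
    }

So primitive nodes compare by stored value, unary nodes by their single operand, and binary nodes by both operands.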
+NODETY_PRIMITIVE_TYPE_FOR_EACH(PRIMITIVE_TYPE_IMPL) +NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL) +NODETY_BINARY_OP_FOR_EACH(BINARY_OP_IMPL) + +#undef PRIMITIVE_TYPE_IMPL +#undef UNARY_OP_IMPL +#undef BINARY_OP_IMPL + +bool IrEqualVisitor::Visit(const Cast* lhs, const Expr* other) { + auto* rhs = other->As(); + return lhs->type() == rhs->type() && Compare(lhs->v(), rhs->v()); +} + +bool IrEqualVisitor::Visit(const For* lhs, const Expr* other) { + auto* rhs = other->As(); + return lhs->for_type() == rhs->for_type() && Compare(lhs->loop_var, rhs->loop_var) && Compare(lhs->min, rhs->min) && + Compare(lhs->extent, rhs->extent) && Compare(lhs->body, rhs->body); +} + +bool IrEqualVisitor::Visit(const PolyFor* lhs, const Expr* other) { + auto* rhs = other->As(); + return lhs->for_type() == rhs->for_type() && Compare(lhs->iterator, rhs->iterator) && Compare(lhs->init, rhs->init) && + Compare(lhs->condition, rhs->condition) && Compare(lhs->inc, rhs->inc) && Compare(lhs->body, rhs->body); +} + +bool IrEqualVisitor::Visit(const Select* lhs, const Expr* other) { + auto* rhs = other->As(); + IRVisitorBase::Visit(&node->condition, &node->condition); + IRVisitorBase::Visit(&node->true_value, &node->true_value); + IRVisitorBase::Visit(&node->false_value, &node->false_value); +} +template +void IRMutator::Visit(const IfThenElse *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->condition, &node->condition); + IRVisitorBase::Visit(&node->true_case, &node->true_case); + if (node->false_case.defined()) IRVisitorBase::Visit(&node->false_case, &node->false_case); +} +template +void IRMutator::Visit(const Block *expr, T op) { + auto *node = op->template As(); + for (auto &expr : node->stmts) { + IRVisitorBase::Visit(&expr, &expr); + } +} +template +void IRMutator::Visit(const Call *expr, T op) { + auto *node = op->template As(); + for (auto &expr : node->read_args) { + IRVisitorBase::Visit(&expr, &expr); + } + for (auto &expr : node->write_args) { + IRVisitorBase::Visit(&expr, &expr); + } +} +template +void IRMutator::Visit(const _Module_ *expr, T op) { + auto *node = op->template As<_Module_>(); + for (auto &func : node->functions) { + IRVisitorBase::Visit(&func, &func); + } + for (auto &func : node->buffers) { + IRVisitorBase::Visit(&func, &func); + } + for (auto &expr : node->submodules) { + IRVisitorBase::Visit(&expr, &expr); + } +} +template +void IRMutator::Visit(const _Var_ *expr, T op) { + auto *node = op->template As(); + if (node->lower_bound.defined()) { + IRVisitorBase::Visit(&node->lower_bound, &node->lower_bound); + } + if (node->upper_bound.defined()) { + IRVisitorBase::Visit(&node->upper_bound, &node->upper_bound); + } +} +template +void IRMutator::Visit(const Load *expr, T op) { + auto *node = op->template As(); + for (auto &idx : node->indices) IRVisitorBase::Visit(&idx, &idx); + IRVisitorBase::Visit(&node->tensor, &node->tensor); +} +template +void IRMutator::Visit(const Store *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->value, &node->value); + IRVisitorBase::Visit(&node->tensor, &node->tensor); + for (auto &idx : node->indices) IRVisitorBase::Visit(&idx, &idx); +} +template +void IRMutator::Visit(const Alloc *expr, T op) { + auto *node = op->template As(); + for (auto &e : node->extents) { + IRVisitorBase::Visit(&e, &e); + } + + if (node->condition.defined()) IRVisitorBase::Visit(&node->condition, &node->condition); + if (node->body.defined()) { + Expr body(node->body); + IRVisitorBase::Visit(&node->body, &body); + } +} +template +void 
IRMutator::Visit(const Free *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->destination, &node->destination); +} +template +void IRMutator::Visit(const _Buffer_ *expr, T op) { + auto *node = op->template As<_Buffer_>(); + + for (auto &e : node->shape) { + IRVisitorBase::Visit(&e, &e); + } + for (auto &e : node->strides) { + IRVisitorBase::Visit(&e, &e); + } + IRVisitorBase::Visit(&node->elem_offset, &node->elem_offset); +} +template +void IRMutator::Visit(const _Tensor_ *expr, T op) { + auto *node = op->template As<_Tensor_>(); + + for (auto &e : node->shape) { + IRVisitorBase::Visit(&e, &e); + } +} +template +void IRMutator::Visit(const _LoweredFunc_ *expr, T op) { + auto *node = op->template As<_LoweredFunc_>(); + IRVisitorBase::Visit(&node->body, &node->body); +} +template +void IRMutator::Visit(const Let *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->symbol, &node->symbol); + if (node->body.defined()) IRVisitorBase::Visit(&node->body, &node->body); +} +template +void IRMutator::Visit(const Reduce *expr, T op) { + auto *node = op->template As(); + if (node->init.defined()) IRVisitorBase::Visit(&node->init, &node->init); + CHECK(node->body.defined()); + IRVisitorBase::Visit(&node->body, &node->body); +} + +template +void IRMutator::Visit(const Ramp *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->base, &node->base); + IRVisitorBase::Visit(&node->stride, &node->stride); +} + +template +void IRMutator::Visit(const Broadcast *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->value, &node->value); +} + +template +void IRMutator::Visit(const FracOp *expr, T op) { + auto *node = op->template As(); + IRVisitorBase::Visit(&node->a(), &node->a()); + IRVisitorBase::Visit(&node->b(), &node->b()); +} + +template +void IRMutator::Visit(const Product *expr, T op) { + auto *node = op->template As(); + for (auto &x : node->operands()) { + IRVisitorBase::Visit(&x, &x); + } +} + +template +void IRMutator::Visit(const Sum *expr, T op) { + auto *node = op->template As(); + for (auto &x : node->operands()) { + IRVisitorBase::Visit(&x, &x); + } +} +template +void IRMutator::Visit(const PrimitiveNode *expr, T op) { + auto *node = op->template As(); + for (auto &args : node->arguments) { + for (auto &arg : args) { + IRVisitorBase::Visit(&arg, &arg); + } + } +} + +template +void IRMutator::Visit(const IntrinsicOp *expr, T op) { + auto *node = op->template As(); + switch (node->getKind()) { + case ir::IntrinsicKind::kBufferGetDataHandle: { + auto *n = llvm::dyn_cast(node); + Visit(&n->buffer, &n->buffer); + } break; + case ir::IntrinsicKind::kBufferGetDataConstHandle: { + auto *n = llvm::dyn_cast(node); + Visit(&n->buffer, &n->buffer); + } break; + case ir::IntrinsicKind::kPodValueToX: { + auto *n = llvm::dyn_cast(node); + Visit(&n->pod_value_ptr, &n->pod_value_ptr); + } break; + case ir::IntrinsicKind::kBuiltinIntrin: { + auto *n = llvm::dyn_cast(node); + for (auto &expr : n->args) { + Visit(&expr, &expr); + } + } break; + } +} + +template +void IRMutator::Visit(const _BufferRange_ *expr, T op) { + auto *node = op->template As<_BufferRange_>(); + CHECK(node); + IRVisitorBase::Visit(&node->buffer, &node->buffer); + for (auto &var : node->ranges) { + if (var->lower_bound.defined()) { + IRVisitorBase::Visit(&var->lower_bound, &var->lower_bound); + } + if (var->upper_bound.defined()) { + IRVisitorBase::Visit(&var->upper_bound, &var->upper_bound); + } + } +} + +template +void IRMutator::Visit(const 
ScheduleBlock *expr, T op) { + auto *node = op->template As(); + CHECK(node); + for (auto &var : node->iter_vars) { + if (var->lower_bound.defined()) { + IRVisitorBase::Visit(&var->lower_bound, &var->lower_bound); + } + if (var->upper_bound.defined()) { + IRVisitorBase::Visit(&var->upper_bound, &var->upper_bound); + } + } + for (auto &buffer_region : node->read_buffers) { + IRVisitorBase::Visit(&buffer_region, &buffer_region); + } + for (auto &buffer_region : node->write_buffers) { + IRVisitorBase::Visit(&buffer_region, &buffer_region); + } + IRVisitorBase::Visit(&(node->body), &(node->body)); +} + +template +void IRMutator::Visit(const ScheduleBlockRealize *expr, T op) { + auto *node = op->template As(); + CHECK(node); + for (auto &value : node->iter_values) { + IRVisitorBase::Visit(&value, &value); + } + IRVisitorBase::Visit(&node->schedule_block, &node->schedule_block); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_operators.cc b/paddle/cinn/ir/ir_operators.cc new file mode 100644 index 0000000000000..cc586971c11b2 --- /dev/null +++ b/paddle/cinn/ir/ir_operators.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_operators.h" + +#include +#include + +#include "cinn/common/target.h" +#include "cinn/common/type.h" +#include "cinn/hlir/op/op_util.h" +#include "cinn/lang/compute.h" +#include "cinn/runtime/flags.h" + +namespace cinn { +namespace ir { +using attr_t = absl::variant; + +Expr operator<<(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + auto int_a = a.As(); + auto int_b = b.As(); + Type t_a = a.type(); + Type t_b = b.type(); + if (t_a.is_index_type() && t_b.is_index_type()) { + if (int_b) { + CHECK(int_b->value >= 0 && int_b->value < t_a.bits()) + << "Shift amount must be non-negative and less than " << t_a.bits() << " for type " << t_a << std::endl; + if (int_b->value == 0) return a; + } + if (int_a && int_b) { + return Expr(int_a->value << int_b->value); + } + } + return lang::CallExtern("left_shift", {a, b}, {{"vectorizable", false}}); +} + +Expr operator>>(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + auto int_a = a.As(); + auto int_b = b.As(); + Type t_a = a.type(); + Type t_b = b.type(); + if (t_a.is_index_type() && t_b.is_index_type()) { + if (int_b) { + CHECK(int_b->value >= 0 && int_b->value < t_a.bits()) + << "Shift amount must be non-negative and less than " << t_a.bits() << " for type " << t_a << std::endl; + if (int_b->value == 0) return a; + } + if (int_a && int_b) { + return Expr(int_a->value >> int_b->value); + } + } + return lang::CallExtern("right_shift", {a, b}, {{"vectorizable", false}}); +} + +Expr operator|(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + auto int_a = a.As(); + auto int_b = b.As(); + Type t_a = 
a.type(); + Type t_b = b.type(); + if (t_a.is_index_type() && t_b.is_index_type()) { + if (int_a && int_b) { + return Expr(int_a->value | int_b->value); + } + } + auto target = cinn::runtime::CurrentTarget::GetCurrentTarget(); + if (target.arch == common::Target::Arch::X86) { + return lang::CallExtern("bitwise_or", {a, b}, {{"vectorizable", false}}); + } else if (target.arch == common::Target::Arch::NVGPU) { + auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_or"); + return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_or."; + } +} + +Expr operator&(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + auto int_a = a.As(); + auto int_b = b.As(); + Type t_a = a.type(); + Type t_b = b.type(); + if (t_a.is_index_type() && t_b.is_index_type()) { + if (int_a && int_b) { + return Expr(int_a->value & int_b->value); + } + } + auto target = cinn::runtime::CurrentTarget::GetCurrentTarget(); + if (target.arch == common::Target::Arch::X86) { + return lang::CallExtern("bitwise_and", {a, b}, {{"vectorizable", false}}); + } else if (target.arch == common::Target::Arch::NVGPU) { + auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_and"); + return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_and."; + } +} + +Expr operator^(Expr a, Expr b) { + CHECK(a.type().is_int() || a.type().is_uint()); + CHECK(b.type().is_int() || b.type().is_uint()); + auto int_a = a.As(); + auto int_b = b.As(); + Type t_a = a.type(); + Type t_b = b.type(); + if (t_a.is_index_type() && t_b.is_index_type()) { + if (int_a && int_b) { + return Expr(int_a->value ^ int_b->value); + } + } + auto target = cinn::runtime::CurrentTarget::GetCurrentTarget(); + if (target.arch == common::Target::Arch::X86) { + return lang::CallExtern("bitwise_xor", {a, b}, {{"vectorizable", false}}); + } else if (target.arch == common::Target::Arch::NVGPU) { + auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_xor"); + return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_xor."; + } +} + +Expr operator~(Expr a) { + CHECK(a.type().is_int() || a.type().is_uint()); + auto target = cinn::runtime::CurrentTarget::GetCurrentTarget(); + if (target.arch == common::Target::Arch::X86) { + return lang::CallExtern("bitwise_not", {a}, {{"vectorizable", false}}); + } else if (target.arch == common::Target::Arch::NVGPU) { + auto func_name = hlir::GetExternFuncName(target, a->type(), "bitwise_not"); + return lang::CallExtern(func_name, {a}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_not."; + } +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_operators.h b/paddle/cinn/ir/ir_operators.h new file mode 100644 index 0000000000000..a2a7b711573aa --- /dev/null +++ b/paddle/cinn/ir/ir_operators.h @@ -0,0 +1,133 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "cinn/common/ir_util.h" +#include "cinn/ir/ir.h" + +namespace cinn { +namespace ir { + +//-- left hand -- +template ::value>::type> +Expr operator+(Expr a, POD b) { + return Add::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator-(Expr a, POD b) { + return Sub::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator*(Expr a, POD b) { + return Mul::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator/(Expr a, POD b) { + return Div::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator%(Expr a, POD b) { + return Mod::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator<(Expr a, POD b) { + return LT::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator<=(Expr a, POD b) { + return LE::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator>(Expr a, POD b) { + return GT::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator>=(Expr a, POD b) { + return GE::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator==(Expr a, POD b) { + return EQ::Make(Expr(a), Expr(b)); +} + +//- right hand -- +template ::value>::type> +Expr operator+(POD a, Expr b) { + return Add::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator-(POD a, Expr b) { + return Sub::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator*(POD a, Expr b) { + return Mul::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator/(POD a, Expr b) { + return Div::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator%(POD a, Expr b) { + return Mod::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator<(POD a, Expr b) { + return LT::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator<=(POD a, Expr b) { + return LE::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator>(POD a, Expr b) { + return GT::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator>=(POD a, Expr b) { + return GE::Make(Expr(a), Expr(b)); +} +template ::value>::type> +Expr operator==(POD a, Expr b) { + return EQ::Make(Expr(a), Expr(b)); +} + +//-- +inline Expr operator+(Expr a, Expr b) { return Add::Make(a, b); } +inline Expr operator-(Expr a, Expr b) { return Sub::Make(a, b); } +inline Expr operator*(Expr a, Expr b) { return Mul::Make(a, b); } +inline Expr operator/(Expr a, Expr b) { return Div::Make(a, b); } +inline Expr operator%(Expr a, Expr b) { return Mod::Make(a, b); } + +inline Expr operator&&(Expr a, Expr b) { return And::Make(Expr(a), Expr(b)); } +inline Expr operator||(Expr a, Expr b) { return Or::Make(Expr(a), Expr(b)); } +inline Expr operator>=(Expr a, Expr b) { return GE::Make(Expr(a), Expr(b)); } +inline Expr operator<=(Expr a, Expr b) { return LE::Make(Expr(a), Expr(b)); } +inline Expr operator>(Expr a, Expr b) { return GT::Make(Expr(a), Expr(b)); } +inline Expr operator<(Expr a, Expr b) { return LT::Make(Expr(a), Expr(b)); } + +inline Expr operator-(Expr a) { return Minus::Make(Expr(a)); } +inline Expr operator!(Expr a) { return Not::Make(Expr(a)); } + +Expr 
operator<<(Expr a, Expr b); +Expr operator>>(Expr a, Expr b); +Expr operator^(Expr a, Expr b); +Expr operator|(Expr a, Expr b); +Expr operator&(Expr a, Expr b); +Expr operator~(Expr a); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_operators_test.cc b/paddle/cinn/ir/ir_operators_test.cc new file mode 100644 index 0000000000000..b31614308e889 --- /dev/null +++ b/paddle/cinn/ir/ir_operators_test.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_operators.h" + +#include + +namespace cinn { +namespace ir { + +TEST(ir_operators, test) { + Expr a(1); + Expr b = a + 1; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_printer.cc b/paddle/cinn/ir/ir_printer.cc new file mode 100644 index 0000000000000..66604da970182 --- /dev/null +++ b/paddle/cinn/ir/ir_printer.cc @@ -0,0 +1,645 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
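
Note that the Expr operator overloads declared above build IR nodes rather than computing values; only the all-constant integer paths in ir_operators.cc fold on the spot. A quick illustration (values arbitrary):

  using cinn::ir::Expr;
  Expr a(5), b(3);
  Expr sum  = a + b;         // an unevaluated Add node, not 8
  Expr cond = sum < 16;      // an LT node, usable as a Select or IfThenElse condition
  Expr shl  = a << Expr(1);  // both operands are IntImm, so operator<< above folds this to Expr(10)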
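
Looking back at the IRMutator<T> specializations at the top of this section, they all follow one shape: cast the visited pointer to the concrete node type, then re-visit every child slot in place so that a subclass can swap out subtrees. A minimal sketch of a concrete pass built on top of it (the struct name and the folding rule are illustrative, not part of this patch):

  // Drops casts whose target type already matches the operand's type.
  struct RedundantCastFolder : public cinn::ir::IRMutator<> {
    void operator()(cinn::ir::Expr *e) { cinn::ir::IRMutator<>::Visit(e, e); }

    void Visit(const cinn::ir::Cast *op, cinn::ir::Expr *expr) override {
      auto *node = expr->As<cinn::ir::Cast>();
      cinn::ir::IRMutator<>::Visit(&node->v(), &node->v());    // rewrite children first
      if (node->v().type() == node->type()) *expr = node->v(); // then drop the no-op cast
    }
  };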
+ +#include "cinn/ir/ir_printer.h" + +#include +#include +#include +#include + +#include "cinn/ir/lowered_func.h" +#include "cinn/ir/module.h" +#include "cinn/ir/tensor.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/runtime/intrinsic.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +using common::bfloat16; +using common::float16; + +void IrPrinter::Print(Expr e) { IRVisitor::Visit(&e); } +void IrPrinter::Print(const std::vector &exprs, const std::string &splitter) { + for (std::size_t i = 0; !exprs.empty() && i + 1 < exprs.size(); i++) { + Print(exprs[i]); + os_ << splitter; + } + if (!exprs.empty()) Print(exprs.back()); +} + +void IrPrinter::Visit(const IntImm *x) { + if (x->type().is_int(64)) { + os_ << x->value << "ll"; + } else if (x->type().is_int(32)) { + os_ << x->value; + } else if (x->type().is_int(16)) { + os_ << "(int16_t)" << x->value; + } else if (x->type().is_int(8)) { + os_ << "(int8_t)" << x->value; + } else { + LOG(FATAL) << "Not support int type: " << x->type(); + } +} +void IrPrinter::Visit(const UIntImm *x) { + if (x->type().is_uint(64)) { + os_ << x->value << "ull"; + } else if (x->type().is_uint(32)) { + os_ << x->value; + } else if (x->type().is_uint(16)) { + os_ << "(uint16_t)" << x->value; + } else if (x->type().is_uint(8)) { + os_ << "(uint8_t)" << x->value; + } else if (x->type().is_uint(1)) { + if (x->value) { + os_ << "true"; + } else { + os_ << "false"; + } + } else { + LOG(FATAL) << "Not support uint type: " << x->type(); + } +} +void IrPrinter::Visit(const FloatImm *x) { + if (x->type().is_float16()) { + if (std::isinf(x->value)) { + os_ << "cinn::common::raw_uint16_to_float16(0x7c00)"; + } else if (std::isnan(x->value)) { + os_ << "cinn::common::raw_uint16_to_float16(0x7e00)"; + } else { + os_ << "(float16)" << std::setprecision(std::numeric_limits::max_digits10) + << static_cast(x->value) << "f"; + } + } else if (x->type().is_bfloat16()) { + if (std::isinf(x->value)) { + os_ << "cinn::common::raw_uint16_to_bfloat16(0x7F80)"; + } else if (std::isnan(x->value)) { + os_ << "cinn::common::raw_uint16_to_bfloat16(0x7FC0)"; + } else { + os_ << "(bfloat16)" << std::setprecision(std::numeric_limits::max_digits10) + << static_cast(x->value) << "f"; + } + } else if (x->type().is_float(32)) { + os_ << std::setprecision(std::numeric_limits::max_digits10) << std::showpoint << x->value; + if (std::isfinite(x->value)) { + os_ << "f"; + } + } else if (x->type().is_float(64)) { + os_ << std::setprecision(std::numeric_limits::max_digits10) << std::showpoint << x->value; + } else { + LOG(FATAL) << "Not support float type: " << x->type(); + } +} +void IrPrinter::Visit(const StringImm *x) { os_ << "\"" << x->value << "\""; } +void IrPrinter::Visit(const Add *x) { PrintBinaryOp("+", x); } +void IrPrinter::Visit(const Sub *x) { PrintBinaryOp("-", x); } +void IrPrinter::Visit(const Mul *x) { PrintBinaryOp("*", x); } +void IrPrinter::Visit(const Div *x) { PrintBinaryOp("/", x); } +void IrPrinter::Visit(const Mod *x) { PrintBinaryOp("%", x); } +void IrPrinter::Visit(const EQ *x) { PrintBinaryOp("==", x); } +void IrPrinter::Visit(const NE *x) { PrintBinaryOp("!=", x); } +void IrPrinter::Visit(const LT *x) { PrintBinaryOp("<", x); } +void IrPrinter::Visit(const LE *x) { PrintBinaryOp("<=", x); } +void IrPrinter::Visit(const GT *x) { PrintBinaryOp(">", x); } +void IrPrinter::Visit(const GE *x) { PrintBinaryOp(">=", x); } +void IrPrinter::Visit(const And *x) { PrintBinaryOp("and", x); } +void IrPrinter::Visit(const Or *x) { PrintBinaryOp("or", x); } 
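
Because operator<<(std::ostream &, Expr) further down routes through IrPrinter, any Expr can be streamed directly; utils::GetStreamCnt elsewhere in this patch relies on the same overload. Roughly, given the Visit methods above:

  using cinn::ir::Expr;
  Expr x = Expr(1) + Expr(2) * Expr(3);
  LOG(INFO) << x;                  // prints (1 + (2 * 3)); int32 immediates print bare
  LOG(INFO) << (x > 3 && x < 10);  // And prints as a keyword: (((1 + (2 * 3)) > 3) and ((1 + (2 * 3)) < 10))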
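
One detail worth calling out in the FloatImm overload above: float16/bfloat16 infinities and NaNs are emitted as raw_uint16_to_float16/raw_uint16_to_bfloat16 calls on raw bit patterns (0x7c00/0x7e00 and 0x7F80/0x7FC0 respectively), since those values have no portable literal spelling in generated source, while finite values are printed with max_digits10 precision so the emitted literal round-trips to the exact same binary value.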
+void IrPrinter::Visit(const Not *x) { + os_ << "!"; + Print(x->v()); +} +void IrPrinter::Visit(const Min *x) { + os_ << "cinn_min("; + Print(x->a()); + os_ << ", "; + Print(x->b()); + os_ << ")"; +} +void IrPrinter::Visit(const Max *x) { + os_ << "cinn_max("; + Print(x->a()); + os_ << ", "; + Print(x->b()); + os_ << ")"; +} +void IrPrinter::Visit(const Minus *x) { + os_ << "-("; + Print(x->v()); + os_ << ")"; +} +void IrPrinter::Visit(const For *x) { + if (x->is_parallel()) { + os() << "parallel for ("; + } else if (x->is_unrolled()) { + os() << "unroll for ("; + } else if (x->is_vectorized()) { + int factor = x->vectorize_info().factor; + os() << "vectorize[" << factor << "] for ("; + } else if (x->is_binded()) { + auto &bind_info = x->bind_info(); + if (bind_info.valid()) { + char axis_name = 'x' + bind_info.offset; + auto for_type = bind_info.for_type; + std::string prefix = for_type == ForType::GPUBlock ? "blockIdx." : "threadIdx."; + os() << "thread_bind[" << prefix << axis_name << "] for ("; + } else { + os() << "thread_bind[invalid info] for ("; + } + } else if (x->is_serial()) { + os() << "serial for ("; + } else if (x->is_default()) { + os() << "default for ("; + } else { + os() << "for ("; + } + Print(x->loop_var); + os_ << ", "; + Print(x->min); + os_ << ", "; + Print(x->extent); + os_ << ")\n"; + + DoIndent(); + Print(x->body); +} + +void IrPrinter::Visit(const PolyFor *x) { + if (x->is_parallel()) { + os() << "parallel poly_for ("; + } else { + os() << "poly_for ("; + } + Print(x->iterator); + os_ << ", "; + Print(x->init); + os_ << ", "; + Print(x->condition); + os_ << ", "; + Print(x->inc); + os_ << ")\n"; + + DoIndent(); + Print(x->body); +} +void IrPrinter::Visit(const IfThenElse *x) { + os_ << "if ("; + Print(x->condition); + os_ << ") {\n"; + IncIndent(); + DoIndent(); + Print(x->true_case); + DecIndent(); + os() << "\n"; + DoIndent(); + os() << "}"; + + if (x->false_case.defined()) { + os_ << " else {\n"; + IncIndent(); + + DoIndent(); + Print(x->false_case); + os() << "\n"; + + DecIndent(); + DoIndent(); + os_ << "}"; + } +} +void IrPrinter::Visit(const Block *x) { + os_ << "{\n"; + + IncIndent(); + for (std::size_t i = 0; !x->stmts.empty() && i + 1 < x->stmts.size(); i++) { + DoIndent(); + Print(x->stmts[i]); + os_ << "\n"; + } + if (!x->stmts.empty()) { + DoIndent(); + Print(x->stmts.back()); + } + DecIndent(); + os_ << "\n"; + DoIndent(); + os_ << "}"; +} +void IrPrinter::Visit(const Call *x) { + os_ << x->name << "("; + if (!x->read_args.empty()) { + for (std::size_t i = 0; i + 1 < x->read_args.size(); i++) { + Print(x->read_args[i]); + os_ << ", "; + } + Print(x->read_args.back()); + } + + if (!x->write_args.empty()) { + if (!x->read_args.empty()) os() << ", "; + + for (std::size_t i = 0; i + 1 < x->write_args.size(); i++) { + Print(x->write_args[i]); + os_ << ", "; + } + Print(x->write_args.back()); + } + + os_ << ")"; +} +void IrPrinter::Visit(const Cast *x) { + os() << x->type(); + os() << "("; + os() << x->v(); + os() << ")"; +} +void IrPrinter::Visit(const _Module_ *x) {} +void IrPrinter::Visit(const _Var_ *x) { os_ << x->name; } +void IrPrinter::Visit(const Alloc *x) { + auto *buffer = x->destination.As(); + CHECK(buffer); + os_ << "alloc(" << buffer->name << ", "; + Print(x->extents); + os_ << ")"; +} +void IrPrinter::Visit(const Select *x) { + os_ << "select("; + Print(x->condition); + os_ << ", "; + Print(x->true_value); + os_ << ", "; + Print(x->false_value); + os_ << ")"; +} +void IrPrinter::Visit(const Load *x) { + if (x->is_addr_tensor()) { + auto 
*tensor = x->tensor.As(); + CHECK(tensor); + os_ << tensor->name; + } else if (x->is_addr_scalar()) { + Print(x->tensor); + } else { + CINN_NOT_IMPLEMENTED + } + + os_ << "["; + for (std::size_t i = 0; i + 1 < x->indices.size(); i++) { + Print(x->indices[i]); + os() << ", "; + } + if (!x->indices.empty()) Print(x->indices.back()); + os_ << "]"; +} +void IrPrinter::Visit(const Store *x) { + if (x->is_addr_tensor()) { + auto *tensor_node = x->tensor.As(); + CHECK(tensor_node); + os_ << tensor_node->name; + } else if (x->is_addr_scalar()) { + Print(x->tensor); + } else { + CINN_NOT_IMPLEMENTED + } + + os_ << "["; + for (std::size_t i = 0; i + 1 < x->indices.size(); i++) { + Print(x->indices[i]); + os() << ", "; + } + if (!x->indices.empty()) Print(x->indices.back()); + os_ << "] = "; + Print(x->value); +} +void IrPrinter::Visit(const Free *x) { + auto *buffer = x->destination.As(); + CHECK(buffer); + os_ << "free(" << buffer->name << ")"; +} + +void IrPrinter::DoIndent() { os_ << std::string(indent_, ' '); } +void IrPrinter::IncIndent() { indent_ += indent_unit; } +void IrPrinter::DecIndent() { indent_ -= indent_unit; } + +void IrPrinter::Visit(const _Buffer_ *x) { + std::vector dim_names; + std::transform(x->shape.begin(), x->shape.end(), std::back_inserter(dim_names), [&](const Expr &x) { + return utils::GetStreamCnt(x); + }); + + os_ << "_Buffer_<" << x->type() << ": " << utils::Join(dim_names, ",") << ">(" << x->name << ")"; +} +void IrPrinter::Visit(const _Tensor_ *x) { + os_ << "Tensor("; + os() << x->name << ", "; + os() << "["; + if (!x->shape.empty()) { + for (std::size_t i = 0; i + 1 < x->shape.size(); i++) { + Print(x->shape[i]); + os() << ","; + } + Print(x->shape.back()); + } + os_ << "])"; +} +void IrPrinter::Visit(const _LoweredFunc_ *f) { + os_ << "function " << f->name << " "; + + std::vector arg_names; + for (auto &arg : f->args) { + arg_names.push_back(arg.name()); + } + os_ << "(" << utils::Join(arg_names, ", ") << ")\n"; + + Print(f->body); +} +void IrPrinter::Visit(const Let *f) { + CHECK(f->type().valid()); + os() << f->type() << " "; + Print(f->symbol); + if (f->body.defined()) { + os() << " = "; + Print(f->body); + } +} + +void IrPrinter::Visit(const Reduce *f) { + os() << "Reduce("; + switch (f->reduce_type) { + case Reduce::ReduceType::kSum: + os() << "sum"; + break; + case Reduce::ReduceType::kSub: + os() << "sub"; + break; + case Reduce::ReduceType::kDiv: + os() << "Div"; + break; + case Reduce::ReduceType::kMul: + os() << "Mul"; + break; + case Reduce::ReduceType::kMax: + os() << "Max"; + break; + case Reduce::ReduceType::kMin: + os() << "Min"; + break; + case Reduce::ReduceType::kAll: + os() << "&&"; + break; + case Reduce::ReduceType::kAny: + os() << "||"; + break; + } + os() << ", "; + Print(f->body); + os() << ","; + Print(f->init); + os() << ")"; +} + +void IrPrinter::Visit(const Ramp *x) { + os() << "Ramp("; + Print(x->base); + os() << ","; + Print(x->stride); + os() << ","; + os() << x->lanes; + os() << ")"; +} + +void IrPrinter::Visit(const Broadcast *x) { + os() << "Broadcast("; + Print(x->value); + os() << ","; + os() << x->lanes; + os() << ")"; +} + +void IrPrinter::Visit(const FracOp *x) { + os() << "("; + Print(x->a()); + os() << " / "; + Print(x->b()); + os() << ")"; +} + +void IrPrinter::Visit(const Product *x) { + os() << "("; + for (std::size_t i = 0; i + 1 < x->operands().size(); i++) { + Print(x->operand(i)); + os() << " * "; + } + if (!x->operands().empty()) Print(x->operands().back()); + os() << ")"; +} + +void IrPrinter::Visit(const Sum *x) 
{ + os() << "("; + for (std::size_t i = 0; i + 1 < x->operands().size(); i++) { + Print(x->operand(i)); + os() << " + "; + } + if (!x->operands().empty()) Print(x->operands().back()); + os() << ")"; +} + +void IrPrinter::Visit(const PrimitiveNode *x) { + os() << x->name << "("; + std::vector args_repr; + for (auto &args : x->arguments) { + std::vector arg_repr; + for (auto &arg : args) { + arg_repr.push_back(utils::GetStreamCnt(arg)); + } + args_repr.push_back(utils::Join(arg_repr, ",")); + } + + os() << utils::Join(args_repr, ","); + os() << ")"; +} + +void IrPrinter::Visit(const _BufferRange_ *x) { + auto *buffer = x->buffer.As(); + CHECK(buffer); + os() << buffer->name << "["; + for (std::size_t i = 0; i < x->ranges.size(); i++) { + if (i) os() << ", "; + auto &range = x->ranges[i]; + os() << range->name << "("; + if (range->lower_bound.defined()) { + os() << range->lower_bound << ":"; + } else { + os() << "undefined:"; + } + + if (range->upper_bound.defined()) { + os() << range->upper_bound; + } else { + os() << "undefined"; + } + os() << ")"; + } + os() << "]"; +} + +void IrPrinter::Visit(const ScheduleBlock *x) {} + +void IrPrinter::Visit(const ScheduleBlockRealize *x) { + auto *schedule_block = x->schedule_block.As(); + os() << "ScheduleBlock(" << schedule_block->name << ")\n"; + DoIndent(); + os() << "{\n"; + // print block vars and bindings + auto iter_vars = schedule_block->iter_vars; + auto iter_values = x->iter_values; + CHECK_EQ(iter_vars.size(), iter_values.size()); + IncIndent(); + if (!iter_vars.empty()) DoIndent(); + for (std::size_t i = 0; i < iter_vars.size(); i++) { + if (i) os() << ", "; + os() << iter_vars[i]->name; + } + if (!iter_vars.empty()) os() << " = axis.bind("; + for (std::size_t i = 0; i < iter_values.size(); i++) { + if (i) os() << ", "; + os() << iter_values[i]; + } + if (!iter_vars.empty()) os() << ")\n"; + // print block body + if (!schedule_block->read_buffers.empty()) { + DoIndent(); + os() << "read_buffers("; + auto &read_buffers = schedule_block->read_buffers; + for (std::size_t i = 0; i < read_buffers.size(); i++) { + if (i) os() << ", "; + Print(read_buffers[i]); + } + os() << ")\n"; + } + if (!schedule_block->write_buffers.empty()) { + DoIndent(); + os() << "write_buffers("; + auto &write_buffers = schedule_block->write_buffers; + for (std::size_t i = 0; i < write_buffers.size(); i++) { + if (i) os() << ", "; + Print(write_buffers[i]); + } + os() << ")\n"; + } + if (!schedule_block->attrs.empty()) { + DoIndent(); + os() << "attrs("; + bool comma = false; + for (auto &&kv : schedule_block->attrs) { + if (comma) os() << ", "; + os() << kv.first << ":"; + absl::visit([this](auto &&arg) { this->os() << arg; }, kv.second); + comma = true; + } + os() << ")\n"; + } + DoIndent(); + Print(schedule_block->body); + os() << "\n"; + DecIndent(); + DoIndent(); + os() << "}"; +} + +void IrPrinter::Visit(const IntrinsicOp *x) { + switch (x->getKind()) { +#define __(op__) \ + case IntrinsicKind::k##op__: \ + Visit(llvm::dyn_cast(x)); \ + break; + + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + } +} +void IrPrinter::Visit(const intrinsics::BufferGetDataHandle *x) { + os() << runtime::intrinsic::buffer_get_data_handle; + Print(x->buffer); + os() << ")"; +} +void IrPrinter::Visit(const intrinsics::BufferGetDataConstHandle *x) { + os() << runtime::intrinsic::buffer_get_data_const_handle; + Print(x->buffer); + os() << ")"; +} +void IrPrinter::Visit(const intrinsics::PodValueToX *x) { + os() << "pod_value_to_"; + os() << x->GetOutputType(0); + os() << "("; + 
Print(x->pod_value_ptr); + os() << ")"; +} +void IrPrinter::Visit(const intrinsics::BufferCreate *x) { + os() << runtime::intrinsic::buffer_create; + os() << "()"; +} +void IrPrinter::Visit(const intrinsics::GetAddr *x) { + os() << "get_addr("; + Print(x->data); + os() << ")"; +} +void IrPrinter::Visit(const intrinsics::ArgsConstruct *x) { + os() << runtime::intrinsic::args_construct_repr; + os() << "("; + Print(std::vector(x->args.begin(), x->args.end())); + os() << ")"; +} + +void IrPrinter::Visit(const intrinsics::BuiltinIntrin *x) { + os_ << runtime::intrinsic::builtin_intrin_repr << "_"; + os_ << x->name << "("; + if (!x->args.empty()) { + for (std::size_t i = 0; i + 1 < x->args.size(); i++) { + Print(x->args[i]); + os_ << ", "; + } + Print(x->args.back()); + } + + os_ << ")"; +} + +std::ostream &operator<<(std::ostream &os, Expr a) { + std::stringstream ss; + IrPrinter printer(ss); + printer.Print(a); + os << ss.str(); + return os; +} + +std::ostream &operator<<(std::ostream &os, const std::vector &a) { + std::stringstream ss; + IrPrinter printer(ss); + printer.Print(a); + os << ss.str(); + return os; +} + +std::ostream &operator<<(std::ostream &os, const ir::Module &m) { + os << "Module " << m->name << " {\n\n"; + for (auto &fn : m->functions) { + os << fn << '\n'; + } + os << "\n\n}"; + return os; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_printer.h b/paddle/cinn/ir/ir_printer.h new file mode 100644 index 0000000000000..7eafbcf97172e --- /dev/null +++ b/paddle/cinn/ir/ir_printer.h @@ -0,0 +1,80 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_visitor.h" + +namespace cinn { + +namespace lang { +class LoweredFunc; +} // namespace lang + +namespace ir { +class Module; + +struct IrPrinter : public IRVisitor { + explicit IrPrinter(std::ostream &os) : os_(os) {} + + //! Emit an expression on the output stream. + void Print(Expr e); + //! Emit a expression list with , splitted. + void Print(const std::vector &exprs, const std::string &splitter = ", "); + //! Emit a binary operator + template + void PrintBinaryOp(const std::string &op, const BinaryOpNode *x); + + //! Prefix the current line with `indent_` spaces. + void DoIndent(); + //! Increase the indent size. + void IncIndent(); + //! Decrease the indent size. 
+ void DecIndent(); + + std::ostream &os() { return os_; } + +#define __(op__) void Visit(const op__ *x) override; + NODETY_FORALL(__) +#undef __ + +#define __(op__) virtual void Visit(const intrinsics::op__ *x); + INTRINSIC_KIND_FOR_EACH(__) +#undef __ + + private: + std::ostream &os_; + uint16_t indent_{}; + const int indent_unit{2}; +}; + +std::ostream &operator<<(std::ostream &os, Expr a); +std::ostream &operator<<(std::ostream &os, const std::vector &a); +std::ostream &operator<<(std::ostream &os, const Module &m); + +template +void IrPrinter::PrintBinaryOp(const std::string &op, const BinaryOpNode *x) { + os_ << "("; + Print(x->a()); + os_ << " " + op + " "; + Print(x->b()); + os_ << ")"; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_printer_test.cc b/paddle/cinn/ir/ir_printer_test.cc new file mode 100644 index 0000000000000..1f9edca6ded05 --- /dev/null +++ b/paddle/cinn/ir/ir_printer_test.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_printer.h" + +#include + +#include + +namespace cinn { +namespace ir {} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_schedule.cc b/paddle/cinn/ir/ir_schedule.cc new file mode 100644 index 0000000000000..eb2d934e0f646 --- /dev/null +++ b/paddle/cinn/ir/ir_schedule.cc @@ -0,0 +1,2310 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_schedule.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/ir_util.h" +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule_util.h" +#include "cinn/ir/ir_visitor.h" +#include "cinn/lang/compute.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/replace_var_with_expr.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +/** + * A struct helps to implement Schedule primitives. 
+ */ +class ScheduleImpl { + public: + ScheduleImpl() = default; + explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false) + : module_expr_(module_expr), debug_flag_(debug_flag) {} + explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {} + + //! Set the debug flag. + void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; } + + //! Get the ModuleExpr stored in ScheduleImpl. + const ModuleExpr& GetModule() const { return module_expr_; } + + void MergeExprs(); + + void SetExprs(const std::vector& exprs) { module_expr_.SetExprs(exprs); } + + bool HasBlock(const std::string& block_name) const; + + std::vector GetLoops(const Expr& block) const; + std::vector GetLoops(const std::string& block_name) const; + std::vector GetAllBlocks() const; + std::vector GetChildBlocks(const Expr& expr) const; + Expr GetBlock(const std::string& block_name) const; + std::vector Split(const Expr& loop, const std::vector& factors); + std::vector SamplePerfectTile(utils::LinearRandomEngine::StateType* rand_seed, + const Expr& loop, + int n, + int max_innermost_factor); + Expr Fuse(const std::vector& loops); + Expr Fuse(const std::string& block_name, const std::vector& loops_index); + Expr Fuse(const Expr& block, const std::vector& loops_index); + void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops); + void SimpleComputeAt(const Expr& block, const Expr& loop); + void ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops); + Expr GetRootBlock(const Expr& expr) const; + Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type); + Expr CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type); + void SyncThreads(const Expr& ir_node, bool after_node = true); + void SetBuffer(Expr& block, const std::string& memory_type, bool fixed = false); + Expr Reorder(const std::vector& loops); + Expr Reorder(const std::string& block_name, const std::vector& loops_index); + Expr Reorder(const Expr& block, const std::vector& loops_index); + DeviceAPI GetDeviceAPI() const; + void MutateForType(const Expr& loop, ForType for_type, int factor = -1); + void Parallel(const Expr& loop); + void Vectorize(const Expr& loop, int factor); + void Unroll(const Expr& loop); + void ComputeInline(const Expr& schedule_block); + void ReverseComputeInline(const Expr& schedule_block); + void Bind(const Expr& loop, const std::string& thread_axis); + Expr Rfactor(const Expr& rf_loop, int rf_axis); + Expr AddUnitLoop(const Expr& block) const; + void Annotate(const Expr& block, const std::string& key, const attr_t& value); + void Unannotate(Expr& block, const std::string& key); + void FlattenLoops(const std::vector& loops, const bool force_flat = false); + void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target); + void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name); + Expr SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed, + const std::vector& candidates, + const std::vector& probs); + + private: + void Replace(const Expr& src_sref, const Expr& tgt_stmt); + + ModuleExpr module_expr_; + bool debug_flag_{false}; +}; + +std::vector ScheduleImpl::Split(const Expr& loop, const std::vector& factors) { + CHECK(loop.As()) << "Expr param of Split must be For node! Please check."; + auto* for_node = loop.As(); + CHECK(common::is_zero(for_node->min)) << "The For node must start with 0! 
Please check."; + CHECK(for_node->extent.is_constant()) << "The For node's extent must be constant! Please check."; + int tot_extent = for_node->extent.get_constant(); + + VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to (" + << cinn::utils::Join(factors, ", ") << ") at loop:\n" + << loop; + + auto processed_factors = ValidateFactors(factors, tot_extent); + int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies()); + std::vector new_loop_vars; + Expr substitute_value(0); + for (int i = 0; i < processed_factors.size(); ++i) { + Var temp_var(common::UniqName(for_node->loop_var->name)); + substitute_value = Expr(temp_var) + substitute_value * Expr(processed_factors[i]); + new_loop_vars.push_back(temp_var); + } + substitute_value = common::AutoSimplify(substitute_value); + Expr new_node = optim::IRCopy(for_node->body); + ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value}); + std::vector splited_loops; + splited_loops.resize(processed_factors.size()); + if (tot_extent < prod_size) { + new_node = IfThenElse::Make(LT::Make(substitute_value, for_node->extent), new_node); + } + for (int i = processed_factors.size() - 1; i >= 0; i--) { + if (!new_node.As()) new_node = Block::Make({new_node}); + new_node = For::Make( + new_loop_vars[i], Expr(0), Expr(processed_factors[i]), for_node->for_type(), for_node->device_api, new_node); + splited_loops[i] = new_node; + } + + this->Replace(loop, new_node); + VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); + return splited_loops; +} + +Expr ScheduleImpl::Fuse(const std::vector& loops) { + VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); + std::vector for_nodes; + std::vector loop_vars; + CHECK(!loops.empty()) << "The loops param of Fuse should not be empty! Please check."; + + for (const Expr& it_loop : loops) { + CHECK(it_loop.As()) << "Expr param of Fuse must be For node! Please check."; + if (!for_nodes.empty()) { + CHECK(for_nodes.back()->body.As()) << "The body of for node is not Block!"; + CHECK_EQ(for_nodes.back()->body.As()->stmts.size(), 1U) << "The Block'size of for node is not 1!"; + CHECK_EQ(for_nodes.back()->body.As()->stmts[0], it_loop) + << "The For nodes in loops param of Fuse must be adjacent! 
Please check."; + } + for_nodes.push_back(it_loop.As()); + loop_vars.push_back(it_loop.As()->loop_var); + } + std::string suffix; + suffix = for_nodes[0]->loop_var->name; + int loops_number = for_nodes.size(); + for (int i = 1; i < loops_number; ++i) { + suffix += "_" + for_nodes[i]->loop_var->name; + } + suffix += "_fused"; + Var fused_var(suffix); + std::vector substitute_value; + substitute_value.resize(loops_number); + Expr fused_expr(fused_var); + for (int i = loops_number - 1; i > 0; i--) { + substitute_value[i] = Mod::Make(fused_expr, for_nodes[i]->extent); + fused_expr = Div::Make(fused_expr, for_nodes[i]->extent); + } + substitute_value[0] = fused_expr; + + Expr fused_body = optim::IRCopy(for_nodes.back()->body); + ReplaceExpr(&fused_body, loop_vars, substitute_value); + optim::Simplify(&fused_body); + Expr fused_extent(1); + for (int i = 0; i < loops_number; ++i) { + fused_extent = fused_extent * for_nodes[i]->extent; + } + fused_extent = common::AutoSimplify(fused_extent); + + if (!fused_body.As()) fused_body = Block::Make({fused_body}); + Expr new_stmt = + For::Make(fused_var, Expr(0), fused_extent, for_nodes[0]->for_type(), for_nodes[0]->device_api, fused_body); + this->Replace(loops[0], new_stmt); + + VLOG(3) << "After fuse, ir is:\n" << new_stmt; + return new_stmt; +} + +Expr ScheduleImpl::Fuse(const std::string& block_name, const std::vector& loops_index) { + std::vector all_loops = this->GetLoops(block_name); + std::vector loops_expr; + loops_expr.reserve(loops_index.size()); + for (int i = 0; i < loops_index.size(); ++i) { + if (i > 0) CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) << "Loops index in Fuse shoule be continuous!"; + } + for (int i : loops_index) { + CHECK_LT(i, (int)all_loops.size()) << "The loop index in Fuse should be less than total loop's number."; + CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; + loops_expr.emplace_back(all_loops[i]); + } + return this->Fuse(loops_expr); +} + +Expr ScheduleImpl::Fuse(const Expr& block, const std::vector& loops_index) { + std::vector all_loops = this->GetLoops(block); + std::vector loops_expr; + loops_expr.reserve(loops_index.size()); + for (int i = 0; i < loops_index.size(); ++i) { + if (i > 0) CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) << "Loops index in Fuse shoule be continuous!"; + } + for (int i : loops_index) { + CHECK_LT(i, (int)all_loops.size()) << "The loop index in Fuse should be less than total loop's number."; + CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; + loops_expr.emplace_back(all_loops[i]); + } + return this->Fuse(loops_expr); +} + +void ScheduleImpl::MutateForType(const Expr& loop, ForType for_type, int factor) { + auto* for_node = loop.As(); + CHECK(for_node) << "loop param must be For node! 
Please check."; + CHECK(for_node->is_serial()) << "loop is not serial, current forloop type is " + << static_cast(for_node->for_type()) << ", and it cannot become " + << static_cast(for_type); + auto loop_copy = optim::IRCopy(loop); + auto* new_for_node = loop_copy.As(); + CHECK(new_for_node); + new_for_node->set_for_type(for_type); + if (new_for_node->is_vectorized()) { + VectorizeInfo vec_info(0, factor); + new_for_node->set_vectorize_info(vec_info); + } else if (new_for_node->is_binded()) { + BindInfo bind_info(for_type, factor, DeviceAPI::GPU); + new_for_node->set_bind_info(bind_info); + } + this->Replace(loop, loop_copy); +} + +void ScheduleImpl::Parallel(const Expr& loop) { MutateForType(loop, ForType::Parallel); } + +void ScheduleImpl::Vectorize(const Expr& loop, int factor) { + CHECK_GT(factor, 0) << "vectorize factor should be more than 0"; + MutateForType(loop, ForType::Vectorized, factor); +} + +void ScheduleImpl::Unroll(const Expr& loop) { MutateForType(loop, ForType::Unrolled); } + +void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) { + static std::set thread_axes = { + "blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "threadIdx.y", "threadIdx.z"}; + CHECK(thread_axes.count(thread_axis)) << "thread_axis " << thread_axis << " is not supported"; + int offset = thread_axis.back() - 'x'; + if (thread_axis[0] == 'b') { + MutateForType(loop, ForType::GPUBlock, offset); + } else { + MutateForType(loop, ForType::GPUThread, offset); + } +} + +// The struct used to mutate new rfactor forloop and its' schedule block. +struct RfMutator : public ir::IRMutator<> { + public: + RfMutator(const Expr& rf_loop, const int& rf_axis) : rf_loop_(rf_loop), rf_axis_(rf_axis) {} + void operator()(Expr* expr) { + auto* rf_for = rf_loop_.As(); + CHECK(rf_for); + old_rf_loop_var_ = rf_for->loop_var; + new_rf_loop_var_ = Var("rf_" + old_rf_loop_var_->name); + IRMutator::Visit(expr, expr); + } + + Tensor GetNewRfTensor() { return new_rf_tensor_; } + + void Visit(const ScheduleBlockRealize* op, Expr* expr) override { + // modify iter_vars and iter_values + auto* node = expr->As(); + CHECK(node); + auto* schedule_block = node->schedule_block.As(); + CHECK(schedule_block); + old_output_name_ = schedule_block->name; + find_tensor_ = false; + auto& block_vars = schedule_block->iter_vars; + auto& iter_values = node->iter_values; + CHECK(old_rf_loop_var_.defined()); + CHECK(new_rf_loop_var_.defined()); + CHECK_EQ(iter_values.size(), block_vars.size()); + int rf_index = -1; + for (int i = 0; i < iter_values.size(); ++i) { + // substitute the old rfactor loop var to new rfactor loop var + if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) { + CHECK_EQ(rf_index, -1) << "only one block var can bind the rfactor loop var"; + CHECK(iter_values[i].As<_Var_>()) << "rfactor loop var not support composite bindings"; + rf_index = i; + optim::ReplaceVarWithExpr(&iter_values[i], old_rf_loop_var_, new_rf_loop_var_); + new_rf_itervar_ = block_vars[i]; + } + } + // create new rfactor block var if not exist + if (rf_index == -1) { + new_rf_itervar_ = Var(cinn::UniqName("i" + std::to_string(block_vars.size()))); + iter_values.push_back(new_rf_loop_var_); + block_vars.push_back(new_rf_itervar_); + } + IRMutator::Visit(&node->schedule_block, &node->schedule_block); + CHECK(find_tensor_) << "not find the store tensor with the schedule block name " << old_output_name_; + schedule_block->name = "rf_" + old_output_name_; + } + + void Visit(const Load* op, Expr* expr) override { + // insert the new 
rfactor indice if not exist + auto* node = expr->As(); + CHECK(node); + auto* tensor = node->tensor.As<_Tensor_>(); + CHECK(tensor); + if (tensor->name == "rf_" + old_output_name_) { + int size = node->indices.size(); + CHECK_LE(rf_axis_, size) << "rf_axis should not be greater than indice size " << size; + CHECK(new_rf_itervar_.defined()); + CHECK(!ContainVar(node->indices, new_rf_itervar_->name)) + << "original output tensor " << old_output_name_ << " should not have the new rfactor index " + << new_rf_itervar_; + node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_); + } + } + + void Visit(const Store* op, Expr* expr) override { + // insert the new rfactor indice if not exist + auto* node = expr->As(); + CHECK(node); + auto* tensor = node->tensor.As<_Tensor_>(); + CHECK(tensor); + if (tensor->name == old_output_name_) { + find_tensor_ = true; + tensor->name = "rf_" + tensor->name; + int size = node->indices.size(); + CHECK_LE(rf_axis_, size) << "rf_axis should not be greater than indice size " << size; + CHECK(!ContainVar(node->indices, new_rf_itervar_->name)) + << "original output tensor " << old_output_name_ << " should not have the new rfactor index " + << new_rf_itervar_; + node->indices.insert(node->indices.begin() + rf_axis_, new_rf_itervar_); + auto* rf_for = rf_loop_.As(); + CHECK(rf_for); + CHECK(is_zero(rf_for->min)) << "rfactor loop's min should be zero"; + auto extent = common::AutoSimplify(rf_for->extent); + auto& shape = tensor->shape; + auto& domain = tensor->domain; + CHECK_LE(rf_axis_, shape.size()) << "rf_axis should not be greater than tensor shape size " << shape.size(); + CHECK_LE(rf_axis_, domain.size()) << "rf_axis should not be greater than tensor domain size " << domain.size(); + shape.insert(shape.begin() + rf_axis_, extent); + domain.insert(domain.begin() + rf_axis_, extent); + if (tensor->buffer.defined()) { + if (tensor->buffer->name.find_first_of("rf") == std::string::npos) { + tensor->buffer->name = "rf_" + tensor->buffer->name; + tensor->buffer->shape = shape; + } + } + new_rf_tensor_ = Tensor(tensor); + } + IRMutator::Visit(&node->value, &node->value); + } + + void Visit(const For* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + depth++; + auto* rf_for = rf_loop_.As(); + CHECK(rf_for); + // erase the original rfactor forloop + if (node->loop_var->name == old_rf_loop_var_->name) { + auto body = node->body.As(); + if (body && body->stmts.size() == 1) { + *expr = body->stmts[0]; + } else { + *expr = node->body; + } + IRMutator::Visit(expr, expr); + } else { + IRMutator::Visit(&node->body, &node->body); + } + if (rf_axis_ == 0 && depth == rf_axis_) { + // insert new rfactor forloop in the rf_axis as serial loop + *expr = For::Make( + new_rf_loop_var_, rf_for->min, rf_for->extent, ForType::Serial, rf_for->device_api, Block::Make({*expr})); + } else if (depth == rf_axis_ - 1) { + // insert new rfactor forloop in the rf_axis as serial loop + node->body = Block::Make( + {For::Make(new_rf_loop_var_, rf_for->min, rf_for->extent, ForType::Serial, rf_for->device_api, node->body)}); + } + depth--; + } + + private: + Expr rf_loop_; + Var old_rf_loop_var_; + Var new_rf_loop_var_; + int rf_axis_; + int depth = -1; + bool find_tensor_ = false; + std::string old_output_name_; + Var new_rf_itervar_; + Tensor new_rf_tensor_; +}; + +// The struct used to mutate final write-back forloop and schedule block. 
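
As an aside before the rest of the rfactor machinery, here is a usage sketch of the primitives defined so far, assuming the public ir::IRSchedule wrapper that forwards to this ScheduleImpl (the block name and extents are hypothetical):

  cinn::ir::IRSchedule sch(mod_expr);  // mod_expr: a lowered ModuleExpr containing block "C"
  auto loops   = sch.GetLoops("C");              // outermost-first loops of block C
  auto factors = sch.Split(loops[0], {-1, 32});  // i -> (i_0, i_1), inner extent 32, -1 inferred
  sch.Bind(factors[0], "blockIdx.x");            // ForType::GPUBlock, offset 0 ('x' - 'x')
  sch.Bind(factors[1], "threadIdx.x");           // ForType::GPUThread, offset 0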
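
As for rfactor itself: RfMutator above builds the partial-reduction half, and FinalMutator below rewrites the original reduction into a write-back over the new tensor. Schematically (tensor and loop names hypothetical), rfactor on loop k with rf_axis = 0 turns

  for (i, 0, M)
    for (k, 0, K)
      B[i] = B[i] + A[i, k]

into a partial-sum stage plus a write-back stage:

  for (rf_k, 0, K)                                 // new serial loop inserted at rf_axis
    for (i, 0, M)
      rf_B[rf_k, i] = rf_B[rf_k, i] + A[i, rf_k]   // RfMutator: partial sums
  for (i, 0, M)
    for (k, 0, K)
      B[i] = B[i] + rf_B[k, i]                     // FinalMutator: write-back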
+struct FinalMutator : public ir::IRMutator<> { + public: + FinalMutator(const Expr& rf_loop, const int& rf_axis, const Tensor& new_rf_tensor) + : rf_loop_(rf_loop), rf_axis_(rf_axis), new_rf_tensor_(new_rf_tensor) {} + void operator()(Expr* expr) { + auto* rf_for = rf_loop_.As(); + CHECK(rf_for); + old_rf_loop_var_ = rf_for->loop_var; + IRMutator::Visit(expr, expr); + } + + void Visit(const ScheduleBlockRealize* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto* schedule_block = node->schedule_block.As(); + CHECK(schedule_block); + auto& iter_vars = schedule_block->iter_vars; + auto& iter_values = node->iter_values; + output_name_ = schedule_block->name; + visit_init_block_ = output_name_.rfind("_init") != std::string::npos; + if (!visit_init_block_) { + for (int i = 0; i < iter_values.size(); ++i) { + if (ContainVar({iter_values[i]}, old_rf_loop_var_->name)) { + // record the rfactor loop var's block var + CHECK(iter_values[i].As<_Var_>()) << "not support complex reduce bindings: " << iter_values[i]; + old_rf_iter_var_ = iter_vars[i]; + break; + } + } + } + IRMutator::Visit(&node->schedule_block, &node->schedule_block); + // modify iter_vars and iter_values, erase other reduce block vars and values + for (auto it = iter_values.begin(); it != iter_values.end(); ++it) { + for (auto erase_var : erase_reduce_loopvars_) { + if (ContainVar({*it}, erase_var)) { + CHECK((*it).As<_Var_>()) << "not support complex reduce bindings: " << *it; + iter_vars.erase(it - iter_values.begin() + iter_vars.begin()); + iter_values.erase(it); + --it; + break; + } + } + } + } + + // currently only support reduce_sum, reduce_mul, reduce_min and reduce_max + void Visit(const Add* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto& oper_b = node->b(); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + } + + void Visit(const Mul* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto& oper_b = node->b(); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + } + + void Visit(const Min* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto& oper_b = node->b(); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + } + + void Visit(const Max* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto& oper_b = node->b(); + oper_b = Load::Make(new_rf_tensor_, new_rf_indice_); + } + + void Visit(const Store* op, Expr* expr) override { + // insert the new rfactor indice if not exist + auto* node = expr->As(); + CHECK(node); + auto* tensor = node->tensor.As<_Tensor_>(); + CHECK(tensor); + CHECK_EQ(tensor->name, output_name_) << "store name should be same with the schedule block name"; + if (!visit_init_block_) { + new_rf_indice_ = node->indices; + CHECK_LE(rf_axis_, new_rf_indice_.size()) + << "rf_axis_ should not be greater than tensor indice size " << new_rf_indice_.size(); + CHECK(old_rf_iter_var_.defined()); + new_rf_indice_.insert(new_rf_indice_.begin() + rf_axis_, old_rf_iter_var_); + IRMutator::Visit(&node->value, &node->value); + } + } + + void Visit(const For* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto* rf_for = rf_loop_.As(); + // erase the reduce forloops after the init block except the rfactor loop + if (visit_init_block_ && node->loop_var->name != old_rf_loop_var_->name) { + erase_reduce_loopvars_.insert(node->loop_var->name); + auto body = node->body.As(); + if (body && body->stmts.size() == 1) { + *expr = body->stmts[0]; + } else { + *expr = node->body; + 
} + IRMutator::Visit(expr, expr); + } else { + IRMutator::Visit(&node->body, &node->body); + } + } + + private: + Expr rf_loop_; + int rf_axis_; + Var old_rf_loop_var_; + Var old_rf_iter_var_; + std::string output_name_; + // collect reduce loop vars except rfactor loop var + std::set erase_reduce_loopvars_; + bool visit_init_block_ = false; + Tensor new_rf_tensor_; + std::vector new_rf_indice_; +}; + +// The struct used to create all stmts after rfactor transformation. +struct RfCreater : public ir::IRMutator<> { + public: + RfCreater(const Expr& root, const Expr& rf_loop, const int& rf_axis) + : root_(root), rf_loop_(rf_loop), rf_axis_(rf_axis) {} + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + Expr CreateRfAllStmts() { + auto root_realize = root_.As(); + CHECK(root_realize); + auto root_block = root_realize->schedule_block.As(); + CHECK(root_block); + Expr root_loop = optim::IRCopy(root_block->body); + if (auto block = root_loop.As()) { + CHECK_EQ(block->stmts.size(), 1U) << "rfactor root should only have one block stmt"; + root_loop = block->stmts[0]; + } + auto* root_for = root_loop.As(); + CHECK(root_for); + auto rf_for = rf_loop_.As(); + CHECK(rf_for); + // create new rfactor forloops + Expr new_rf_forloop = optim::IRCopy(root_loop); + RfMutator rf_mutator(rf_loop_, rf_axis_); + rf_mutator(&new_rf_forloop); + VLOG(3) << "After RfMutator, new rf_forloop is\n" << new_rf_forloop; + auto new_rf_tensor = rf_mutator.GetNewRfTensor(); + // create final write-back forloops + Expr final_forloop = optim::IRCopy(root_loop); + FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor); + final_mutator(&final_forloop); + VLOG(3) << "After FinalMuator, final write-back forloop is\n" << final_forloop; + // combine the new created rfactor forloops with the final write-back forloops and replace + root_block->body = Block::Make({new_rf_forloop, final_forloop}); + return new_rf_tensor; + } + + Expr root_; + Expr rf_loop_; + int rf_axis_; +}; + +Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { + CHECKRfactorValidation(rf_loop, rf_axis); + // get root ScheduleBlockRealize + Expr root = GetRootBlock(rf_loop); + // create all stmts after rfactor transformation + RfCreater rf_create(root, rf_loop, rf_axis); + // return new created rfactor tensor + return rf_create.CreateRfAllStmts(); +} + +struct CacheReadRewriter : public ir::IRMutator<> { + public: + static Expr Rewrite(const Expr& root, CacheBlockInfo* info) { + CacheReadRewriter rewriter(root, info); + Expr new_root = optim::IRCopy(root); + rewriter(&new_root); + return new_root; + } + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + explicit CacheReadRewriter(const Expr& root, CacheBlockInfo* info) : root_(root), info_(info) {} + + void Visit(const ir::Block* expr, Expr* op) override { + if (*op == info_->loc_block) { + IRMutator::Visit(expr, op); + op->As()->stmts.insert(op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); + } else { + IRMutator::Visit(expr, op); + } + } + + void Visit(const ir::Load* expr, Expr* op) override { + if (expr->tensor == Expr(info_->read_tensor)) { + IRMutator::Visit(expr, op); + op->As()->tensor = Expr(info_->write_tensor); + } else { + IRMutator::Visit(expr, op); + } + } + + private: + /*! \brief The parent scope of the insertion */ + const Expr& root_; + /*! 
\brief The info for inserting cache stage */ + CacheBlockInfo* info_; +}; + +struct CacheWriteRewriter : public ir::IRMutator<> { + public: + static Expr Rewrite(const Expr& root, CacheBlockInfo* info) { + CacheWriteRewriter rewriter(root, info); + Expr new_root = optim::IRCopy(root); + rewriter.mutate_cache_block = true; + rewriter(&info->cache_block); + rewriter.mutate_cache_block = false; + rewriter(&new_root); + auto find_tensor = ir::CollectIRNodesWithoutTensor( + new_root, + [&](const Expr* x) { return x->As() && (x->As()->tensor == Expr(info->read_tensor)); }, + true); + if (!find_tensor.empty()) { + auto find_store = ir::CollectIRNodesWithoutTensor((*find_tensor.begin()), [&](const Expr* x) { + return x->As() && (x->As()->tensor == Expr(info->write_tensor)); + }); + for (auto load_ir : find_store) { + load_ir.As()->tensor = Expr(info->read_tensor); + } + } + return new_root; + } + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + explicit CacheWriteRewriter(const Expr& root, CacheBlockInfo* info) : root_(root), info_(info) {} + + void Visit(const ir::Block* expr, Expr* op) override { + if (*op == info_->loc_block) { + IRMutator::Visit(expr, op); + op->As()->stmts.insert(op->As()->stmts.begin() + info_->loc_pos, info_->cache_block); + } else { + IRMutator::Visit(expr, op); + } + } + + void Visit(const ir::ScheduleBlock* expr, Expr* op) override { + if (op->As()->name == info_->write_tensor->name) { + op->As()->name = info_->read_tensor->name; + } else if (op->As()->name == info_->read_tensor->name) { + op->As()->name = info_->write_tensor->name; + } + IRMutator::Visit(expr, op); + } + + void Visit(const ir::Load* expr, Expr* op) override { + IRMutator::Visit(expr, op); + if (op->As()->tensor == Expr(info_->write_tensor) && mutate_cache_block) { + op->As()->tensor = Expr(info_->read_tensor); + } else if (op->As()->tensor == Expr(info_->read_tensor) && mutate_cache_block) { + op->As()->tensor = Expr(info_->write_tensor); + } + } + + void Visit(const ir::Store* expr, Expr* op) override { + IRMutator::Visit(expr, op); + if (op->As()->tensor == Expr(info_->write_tensor)) { + op->As()->tensor = Expr(info_->read_tensor); + } else if (op->As()->tensor == Expr(info_->read_tensor) && mutate_cache_block) { + op->As()->tensor = Expr(info_->write_tensor); + } + } + + private: + /*! \brief The parent scope of the insertion */ + const Expr& root_; + /*! \brief The info for inserting cache stage */ + CacheBlockInfo* info_; + /*! \brief Are we mutating the cache tensor's block */ + bool mutate_cache_block{true}; +}; + +//! Visit all ScheduleBlock and change its body to ir::Block if it is not. 
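
Complementing the two rewriters above, a sketch of how the cache primitives are meant to be driven, extending the hypothetical sch from the earlier sketch (buffer indices and memory types illustrative):

  auto blk  = sch.GetBlock("C");
  auto rbuf = sch.CacheRead(blk, 0, "shared");  // stage C's first read operand through shared memory
  auto wbuf = sch.CacheWrite(blk, 0, "local");  // write C into a local buffer, then copy back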
+struct ChangeBodyToBlock : public ir::IRMutator<> { + public: + static void Change(Expr* expr) { + ChangeBodyToBlock mutator; + mutator(expr); + } + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::ScheduleBlock* expr, Expr* op) override { + if (!op->As()->body.As()) { + op->As()->body = Block::Make({op->As()->body}); + } + IRMutator::Visit(expr, op); + } +}; + +DeviceAPI ScheduleImpl::GetDeviceAPI() const { + auto exprs = this->GetModule().GetExprs(); + auto find_for_nodes = ir::CollectIRNodesWithoutTensor( + exprs.front(), [&](const Expr* x) { return x->As(); }, true); + CHECK(!find_for_nodes.empty()); + return (*find_for_nodes.begin()).As()->device_api; +} + +Expr ScheduleImpl::CacheRead(const Expr& block, int read_tensor_index, const std::string& memory_type) { + CHECK(block.As()); + auto root = GetRootBlock(block); + ChangeBodyToBlock::Change(&root); + Expr read_expr = GetNthAccessExpr(block, read_tensor_index, false); + CHECK(read_expr.As()); + auto tensor_indices = read_expr.As()->indices; + CacheBlockInfo info; + info.read_tensor = read_expr.As()->tensor.as_tensor_ref(); + info.write_tensor = MakeCacheTensor(info.read_tensor, memory_type); + info.alloc = info.write_tensor; + + auto read_ranges = CalculateTensorRegions(block, tensor_indices, info.read_tensor, root); + auto new_block = MakeCacheBlock(read_ranges, &info, memory_type, this->GetDeviceAPI()); + FindInsertionPoint(root, &info, false); + auto new_root = CacheReadRewriter::Rewrite(root, &info); + this->Replace(root.As()->schedule_block.As()->body, + new_root.As()->schedule_block.As()->body); + return new_block; +} + +Expr ScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type) { + CHECK(block.As()); + auto root = GetRootBlock(block); + ChangeBodyToBlock::Change(&root); + Expr write_expr = GetNthAccessExpr(block, write_buffer_index, true); + CHECK(write_expr.As()); + Tensor write_tensor = write_expr.As()->tensor.as_tensor_ref(); + auto tensor_indices = write_expr.As()->indices; + CacheBlockInfo info; + info.read_tensor = MakeCacheTensor(write_tensor, memory_type); + info.write_tensor = write_tensor; + info.alloc = info.read_tensor; + auto write_ranges = CalculateTensorRegions(block, tensor_indices, info.write_tensor, root); + auto new_block = MakeCacheBlock(write_ranges, &info, memory_type, this->GetDeviceAPI()); + FindInsertionPoint(root, &info, true); + + auto new_root = CacheWriteRewriter::Rewrite(root, &info); + this->Replace(root.As()->schedule_block.As()->body, + new_root.As()->schedule_block.As()->body); + + auto find_cache_block = ir::CollectIRNodesWithoutTensor( + root, + [&](const Expr* x) { + return x->As() && !x->As()->iter_values.empty() && + GetTensor(*x)->name == info.read_tensor->name; + }, + true); + + CHECK(info.write_tensor->buffer.defined()); + + // Replace buffer + auto all_tensors = ir::CollectIRNodesWithoutTensor( + root, [&](const Expr* x) { return x->as_tensor() && x->as_tensor()->buffer.defined(); }); + + for (auto i : all_tensors) { + if (i.as_tensor()->name != info.write_tensor->name && i.as_tensor()->buffer.defined() && + i.as_tensor()->buffer->name == info.write_tensor->buffer->name) { + i.as_tensor()->Bind(info.read_tensor->buffer); + } + } + + CHECK_EQ(find_cache_block.size(), 1U); + + return *find_cache_block.begin(); +} + +struct InsertExpr : public ir::IRMutator<> { + public: + static void Insert(const Expr& ir_node, const Expr& insert_node, bool after_node, Expr* expr) { + InsertExpr 
mutator(ir_node, insert_node, after_node); + mutator(expr); + } + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + explicit InsertExpr(const Expr& ir_node, const Expr& insert_node, bool after_node) + : ir_node_(ir_node), insert_node_(insert_node), after_node_(after_node) {} + + void Visit(const ir::Block* expr, Expr* op) override { + for (int i = 0; i < expr->stmts.size(); i++) { + if (expr->stmts[i] == ir_node_) { + if (after_node_) { + op->As()->stmts.insert(op->As()->stmts.begin() + i + 1, insert_node_); + } else { + op->As()->stmts.insert(op->As()->stmts.begin() + i, insert_node_); + } + return; + } + } + IRMutator::Visit(expr, op); + } + + void Visit(const ir::For* expr, Expr* op) override { + if (expr->body == ir_node_) { + if (after_node_) + op->As()->body = ir::Block::Make({op->As()->body, insert_node_}); + else + op->As()->body = ir::Block::Make({insert_node_, op->As()->body}); + return; + } + IRMutator::Visit(expr, op); + } + + private: + const Expr& ir_node_; + const Expr& insert_node_; + bool after_node_; +}; + +void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) { + CHECK(ir_node.As() || ir_node.As()); + auto root = GetRootBlock(ir_node); + ChangeBodyToBlock::Change(&root); + Expr sync_threads = runtime::IntrinsicCall(Void(), "__syncthreads", {}); + InsertExpr::Insert(ir_node, sync_threads, after_node, &root); + return; +} + +/** + * Replace a For node to another For node. + * @param src_sref The For node to be changed. + * @param tgt_stmt The For node we want. + */ +void ScheduleImpl::Replace(const Expr& src_sref, const Expr& tgt_stmt) { + CHECK(src_sref.As() || src_sref.As() || src_sref.As()); + CHECK(tgt_stmt.As() || tgt_stmt.As() || tgt_stmt.As()); + if (src_sref == tgt_stmt) { + return; + } + struct ForLoopMutator : public ir::IRMutator<> { + ForLoopMutator(const Expr& source, const Expr& target) : source_(source), target_(target) {} + + void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + void Visit(const ir::For* op, Expr* expr) override { + if (*expr == source_) { + *expr = target_; + return; + } + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override { + if (*expr == source_) { + *expr = target_; + return; + } + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::Block* op, Expr* expr) override { + if (*expr == source_) { + *expr = target_; + return; + } + ir::IRMutator<>::Visit(op, expr); + } + + const Expr& source_; + const Expr& target_; + }; + auto exprs = module_expr_.GetExprs(); + ForLoopMutator mutator(src_sref, tgt_stmt); + for (auto& i : exprs) { + mutator(&i); + } +} + +Expr ScheduleImpl::Reorder(const std::vector& loops) { + if (loops.size() <= 1) { + return Expr{nullptr}; + } + VLOG(4) << "Before Reorder, ir is:\n" << loops[0]; + + std::set loop_set = CollectLoopsToSet(loops); + auto boundary = GetBoundaryOfReorderRange(loop_set); + Expr top = boundary.first; + Expr bottom = boundary.second; + std::vector chain = GetLoopsInRange(top, bottom); + std::vector if_nodes = GetIfThenElseInRange(top, bottom); + Expr new_loop = ConstructNewLoopChain(chain, loops, loop_set, if_nodes); + this->Replace(top, new_loop); + + VLOG(4) << "After Reorder, ir is:\n" << new_loop; + return new_loop; +} + +Expr ScheduleImpl::Reorder(const std::string& block_name, const std::vector& loops_index) { + std::vector all_loops = this->GetLoops(block_name); + std::vector loops_expr; + loops_expr.reserve(loops_index.size()); + for (int i : loops_index) { + 
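+    // Illustrative: Reorder("B", {2, 0, 1}) permutes the loops around block "B"
+    // from (i, j, k) to (k, i, j); each index below must be a valid loop position.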
CHECK_LT(i, (int)all_loops.size()) << "The loop index in Reorder should be less than total loop's number."; + CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; + loops_expr.emplace_back(all_loops[i]); + } + return this->Reorder(loops_expr); +} + +Expr ScheduleImpl::Reorder(const Expr& block, const std::vector& loops_index) { + std::vector all_loops = this->GetLoops(block); + std::vector loops_expr; + loops_expr.reserve(loops_index.size()); + for (int i : loops_index) { + CHECK_LT(i, (int)all_loops.size()) << "The loop index in Reorder should be less than total loop's number."; + CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; + loops_expr.emplace_back(all_loops[i]); + } + return this->Reorder(loops_expr); +} + +Expr ScheduleImpl::GetRootBlock(const Expr& expr) const { + auto exprs = this->GetModule().GetExprs(); + for (auto& it_expr : exprs) { + auto find_expr = ir::CollectIRNodesWithoutTensor( + it_expr, [&](const Expr* x) { return x->node_type() == expr.node_type() && *x == expr; }, true); + if (!find_expr.empty()) { + CHECK(it_expr.As()); + CHECK_EQ(it_expr.As()->stmts.size(), 1U); + CHECK(it_expr.As()->stmts[0].As()); + return it_expr.As()->stmts[0]; + } + } + LOG(FATAL) << "Didn't find expr \n" << expr << "in ScheduleImpl:\n" << exprs[0]; +} + +// The struct used to reconstruct the new For node to replace the old For node. +struct LoopReconstructor : public ir::IRMutator<> { + public: + explicit LoopReconstructor(const Expr& root, const Expr& block, const Expr& loop) + : root_(root), block_(block), loop_(loop) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + /* \param inserted_pos The position index of the new_loop_ body `stmts` to be inserted: + * - `index = -1` means inserted into the tail + * - otherwise, it should be a index between [0, stmts size) + */ + std::string MakeNewLoop(const std::vector& iter_ranges, bool keep_unit_loops, int inserted_pos = -1) { + int n_iters = iter_ranges.size(); + std::vector loop_vars; + std::vector loop_extents; + std::vector iter_values; + loop_vars.reserve(n_iters); + loop_extents.reserve(n_iters); + iter_values.reserve(n_iters); + std::vector new_var_names; + for (int i = 0; i < n_iters; ++i) { + const auto& range = iter_ranges[i]; + if (keep_unit_loops || range.extent != Expr(1)) { + std::string var_name = common::UniqName("ax" + std::to_string(loop_vars.size())); + new_var_names.push_back(var_name); + Var var(var_name, Int(32)); + loop_vars.push_back(var); + loop_extents.push_back(range.extent); + iter_values.push_back(common::AutoSimplify(range.min) + var); + } else { + iter_values.push_back(common::AutoSimplify(range.min)); + } + } + auto schedule_block_node = block_.As()->schedule_block; + new_block_ = ScheduleBlockRealize::Make(std::move(iter_values), std::move(schedule_block_node)); + Expr loop_body = new_block_; + for (int i = static_cast(loop_vars.size()) - 1; i >= 0; --i) { + auto loop_var = loop_vars[i]; + auto loop_extent = loop_extents[i]; + if (!loop_body.As()) loop_body = Block::Make({loop_body}); + loop_body = For::Make( + loop_var, Expr(0), loop_extent, ForType::Serial, loop_.As()->device_api, std::move(loop_body)); + } + new_loop_ = optim::IRCopy(loop_); + + // Replace the copied Tensor object with the original Tensor object, + // to ensure that the same Tensor in a AST is the same object. 
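+    // Schematic example: after IRCopy, a Store to tensor "B" inside new_loop_
+    // points at a fresh _Tensor_ copy; the code below re-binds it, by name, to
+    // the "B" object collected from the original loop_.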
+ std::unordered_map tensors_map; + ir::CollectIRNodesWithoutTensor(loop_, [&tensors_map](const Expr* x) { + if (x->as_tensor()) { + tensors_map.insert({x->as_tensor()->name, *x}); + return true; + } + return false; + }); + auto find_store = ir::CollectIRNodesWithoutTensor(new_loop_, [](const Expr* x) { return x->As(); }); + for (auto store : find_store) { + store.As()->tensor = tensors_map.at(store.As()->tensor.as_tensor()->name); + } + auto find_load = ir::CollectIRNodesWithoutTensor(new_loop_, [](const Expr* x) { return x->As(); }); + for (auto load : find_load) { + load.As()->tensor = tensors_map.at(load.As()->tensor.as_tensor()->name); + } + + InsertBlock(new_loop_, loop_body, inserted_pos); + return utils::Join(new_var_names, ","); + } + + private: + public: + /*! \brief The root block */ + Expr root_; + /*! \brief The given block to be moved */ + Expr block_; + /*! \brief The given loop the block and its loop nest to be put under */ + Expr loop_; + /*! \brief The new loop to replace the original loop */ + Expr new_loop_{nullptr}; + /*! \brief The new block realize to the moved block */ + Expr new_block_{nullptr}; + /*! \brief The plan to remove the given block by replacing this loop/block in the AST */ + Expr source_expr{nullptr}; + /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ + Expr target_expr{nullptr}; +}; + +struct FixLocalBufferSize : public ir::IRMutator<> { + public: + FixLocalBufferSize(const std::string& tensor_name) : tensor_name_(tensor_name) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::Store* expr, Expr* op) override { + if (op->As()->tensor.As<_Tensor_>()->name == tensor_name_) { + op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->buffer->shape = {Expr(1)}; + op->As()->indices = {Expr(0)}; + } + IRMutator::Visit(expr, op); + } + + void Visit(const ir::Load* expr, Expr* op) override { + if (op->As()->tensor.As<_Tensor_>()->name == tensor_name_) { + op->As()->tensor.As<_Tensor_>()->shape = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->domain = {Expr(1)}; + op->As()->tensor.As<_Tensor_>()->buffer->shape = {Expr(1)}; + op->As()->indices = {Expr(0)}; + } + IRMutator::Visit(expr, op); + } + std::string tensor_name_; +}; + +void ScheduleImpl::SetBuffer(Expr& block, const std::string& memory_type, bool fixed) { + CHECK(block.As()); + auto find_tensor = ir::CollectIRNodesWithoutTensor( + block, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except for root block)"; + auto& tensor = (*find_tensor.begin()).As()->tensor; + tensor.as_tensor_ref()->WithBuffer(memory_type, "_" + tensor.as_tensor_ref()->name + "_temp_buffer"); + + auto exprs = this->GetModule().GetExprs(); + for (auto& it_expr : exprs) { + auto find_tensor = ir::CollectIRNodesWithoutTensor(it_expr, [&](const Expr* x) { + return x->as_tensor() && (x->as_tensor()->name == tensor.as_tensor_ref()->name || + x->as_tensor()->name == tensor.as_tensor_ref()->name + "__reduce_init"); + }); + for (auto& t : find_tensor) { + CHECK(t.as_tensor()); + t.as_tensor_ref()->Bind(tensor.as_tensor_ref()->buffer); + } + } + + // if buffer type == "local" + if (memory_type == "local" && fixed) { + FixLocalBufferSize mutator(block.As()->schedule_block.As()->name); + auto root = GetRootBlock(block); + mutator(&root); + } +} + +void ScheduleImpl::MergeExprs() { + auto 
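+  // Sketch of the merge (schematic): exprs {E0, E1} with root bodies b0 and b1
+  // become a single expr whose root ScheduleBlock body is Block {b0, b1}.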
exprs = this->GetModule().GetExprs(); + if (exprs.size() == 1U) return; + CHECK(exprs[0].As()); + CHECK_EQ(exprs[0].As()->stmts.size(), 1U); + CHECK(exprs[0].As()->stmts[0].As()); + CHECK(exprs[0].As()->stmts[0].As()->schedule_block.As()); + std::vector merged_block; + merged_block.push_back( + exprs[0].As()->stmts[0].As()->schedule_block.As()->body); + VLOG(3) << "Before merging, exprs[0] is : " << exprs[0]; + for (int i = 1; i < exprs.size(); ++i) { + auto root_block = ir::CollectIRNodesWithoutTensor( + exprs[i], + [&](const Expr* x) { + return x->As() && x->As()->iter_values.empty(); + }, + true); + CHECK_EQ(root_block.size(), 1U); + for (auto& it_block : root_block) { + auto& block_body = it_block.As()->schedule_block.As()->body; + merged_block.push_back(block_body); + } + } + for (auto& block : merged_block) { + VLOG(3) << "in merged_block, it has " << block; + } + auto merged_expr = ir::Block::Make(merged_block); + exprs[0].As()->stmts[0].As()->schedule_block.As()->body = + merged_expr; + VLOG(3) << "After merging, exprs[0] is : " << exprs[0]; + exprs.erase(exprs.begin() + 1, exprs.end()); + this->SetExprs(exprs); +} + +void ScheduleImpl::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { + CHECK(block.As()); + CHECK(loop.As()); + Expr root = this->GetRootBlock(block); + + VLOG(3) << "Begin ComputeAt of loop:\n" << loop << "\nat block:\n" << root; + + auto producers = GetProducers(block, root); + auto consumers = GetConsumers(block, root); + CheckComputeAtValidation(block, loop, root); + LoopReconstructor reconstructor(root, block, loop); + LeafBlockRemovalPlan remove_plan(block, &reconstructor.source_expr, &reconstructor.target_expr); + remove_plan(&root); + auto iter_ranges = CalculateRequiredRegions(block, loop, root, consumers); + std::string new_var_names = reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, 0); + auto sch_block_expr = block.As()->schedule_block; + sch_block_expr.As()->attrs.emplace(ir::attr::compute_at_extra_var, new_var_names); + this->Replace(reconstructor.source_expr, reconstructor.target_expr); + this->Replace(reconstructor.loop_, reconstructor.new_loop_); + + VLOG(3) << "After SimpleComputeAt, ir is:\n" << reconstructor.new_loop_; +} + +void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { + CHECK(block.As()); + CHECK(loop.As()); + std::vector block_loops = this->GetLoops(block); + Expr root = this->GetRootBlock(block); + auto loops = GetLoopsOfExpr(loop, root); + + VLOG(3) << "Begin SimpleComputeAt of loop:\n" << loop << "\nat block:\n" << root; + + auto this_loop = loop; + auto block_name = GetTensor(block)->name; + auto this_block = block; + if (GetLoopExtent(loops[0]) == 1 && GetLoopExtent(block_loops[0]) != 1) { + this->Split(block_loops[0], {1, -1}); + this_block = this->GetBlock(block_name); + } else if (GetLoopExtent(loops[0]) != 1 && GetLoopExtent(block_loops[0]) == 1) { + auto splited = this->Split(loops[0], {1, -1}); + this_loop = splited[1]; + } + + block_loops = this->GetLoops(this_block); + root = this->GetRootBlock(this_block); + loops = GetLoopsOfExpr(this_loop, root); + + CHECK_LE(loops.size(), block_loops.size()); + + std::vector replaced_var; + std::vector substitute_expr; + for (int i = 0; i < loops.size(); ++i) { + CHECK_EQ(GetLoopExtent(loops[i]), GetLoopExtent(block_loops[i])); + if (block_loops[i].As()->bind_info().valid() && !loops[i].As()->bind_info().valid()) { + loops[i].As()->set_bind_info(block_loops[i].As()->bind_info()); + } + replaced_var.push_back(block_loops[i].As()->loop_var); 
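+    // e.g. the block's own loop var (say i0_b) will later be substituted by the
+    // target loop's var i0 when the copied body is moved under `loop`.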
+ substitute_expr.push_back(Expr(loops[i].As()->loop_var)); + } + + Expr result = + loops.size() < block_loops.size() ? optim::IRCopy(block_loops[loops.size()]) : optim::IRCopy(this_block); + Expr new_loop = optim::IRCopy(this_loop); + + // Get the body of block_loop under the same loops + auto body = block_loops.at(loops.size() - 1).As()->body; + // collect if + auto if_checker = [](const Expr* x) { return x->As(); }; + auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker); + for (auto if_expr : if_set) { + auto checker = [block_name](const Expr* x) { + return x->As() && + x->As()->schedule_block.As()->name == block_name; + }; + if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) { + result = IfThenElse::Make(if_expr.As()->condition, result); + break; + } + } + + ReplaceExpr(&result, replaced_var, substitute_expr); + // When there are two identical IfThenElse + if (new_loop.As() && new_loop.As()->body.As() && + new_loop.As()->body.As()->stmts[0].As()) { + auto if_then_else = new_loop.As()->body.As()->stmts[0]; + if (result.As() && + if_then_else.As()->condition == result.As()->condition) { + new_loop.As()->body.As()->stmts[0].As()->true_case = + ir::Block::Make({result.As()->true_case, + new_loop.As()->body.As()->stmts[0].As()->true_case}); + } else { + std::vector::iterator pos = new_loop.As()->body.As()->stmts.begin(); + new_loop.As()->body.As()->stmts.insert(pos, result); + } + } else { + new_loop.As()->body = ir::Block::Make({result, new_loop.As()->body}); + } + + Expr source_expr{nullptr}; + Expr target_expr{nullptr}; + + LeafBlockRemovalPlan remove_plan( + result.As() ? block_loops[loops.size()] : this_block, &source_expr, &target_expr); + remove_plan(&root); + + this->Replace(source_expr, target_expr); + this->Replace(this_loop, new_loop); + + VLOG(3) << "After SimpleComputeAt, ir is:\n" << new_loop; +} + +void ScheduleImpl::ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { + CHECK(block.As()); + CHECK(loop.As()); + Expr root = this->GetRootBlock(block); + auto producers = GetProducers(block, root); + auto consumers = GetConsumers(block, root); + CheckComputeAtValidation(block, loop, root); + LoopReconstructor reconstructor(root, block, loop); + LeafBlockRemovalPlan remove_plan(block, &reconstructor.source_expr, &reconstructor.target_expr); + remove_plan(&root); + auto iter_ranges = CalculateRequiredRegions(block, loop, root, producers, false); + std::string new_var_names = reconstructor.MakeNewLoop(iter_ranges, keep_unit_loops, -1); + auto sch_block_expr = block.As()->schedule_block; + sch_block_expr.As()->attrs.emplace(ir::attr::reverse_compute_at_extra_var, new_var_names); + this->Replace(reconstructor.source_expr, reconstructor.target_expr); + this->Replace(reconstructor.loop_, reconstructor.new_loop_); + return; +} + +void BaseInliner::operator()(Expr* expr) { + IRMutator::Visit(&tgt_stmt, &tgt_stmt); + IRMutator::Visit(expr, expr); +} + +void BaseInliner::Visit(const ir::Block* expr, Expr* op) { + if (*op == src_stmt) { + *op = tgt_stmt; + return; + } + IRMutator::Visit(expr, op); +} + +bool BaseInliner::UpdateAndCheckIndexVars(const std::vector& indices, int expected_ndim) { + int n = indices.size(); + if (n != expected_ndim) { + return false; + } + std::vector result; + result.reserve(n); + for (auto& i : indices) { + if (i.as_var()) { + result.push_back(i.as_var_ref()); + } else { + return false; + } + } + int n_distinct = std::set(result.begin(), result.end()).size(); + if (n != n_distinct) { + return false; + } + 
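+  // Illustrative: indices (i, j, k) pass; (i, i) fails the distinctness check
+  // above, and (i + 1, j) fails earlier because i + 1 is not a plain Var.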
if (idx_vars_.empty()) { + idx_vars_ = std::move(result); + } else { + if (idx_vars_.size() != result.size()) return false; + for (int i = 0; i < result.size(); ++i) { + if (Expr(idx_vars_[i]) != Expr(result[i])) return false; + } + } + return true; +} + +void BaseInliner::SetIndexSubstitution(const std::vector& indices) { + CHECK_EQ(indices.size(), idx_vars_.size()); + int n = idx_vars_.size(); + idx_sub_var_.reserve(n); + idx_sub_expr_.reserve(n); + for (int i = 0; i < n; ++i) { + idx_sub_var_.push_back(idx_vars_[i]); + idx_sub_expr_.push_back(indices[i]); + } +} + +bool ComputeInliner::BodyPatternAllowInline() { + if (!inlined_store_.defined()) { + return false; + } + CHECK(inlined_store_.As()); + auto find_vars = ir::CollectIRNodesWithoutTensor(inlined_store_, [&](const Expr* x) { return x->as_var(); }); + std::set vars_set; + for (auto& i : find_vars) vars_set.insert(i.as_var_ref()); + int n_vars = vars_set.size(); + if (!UpdateAndCheckIndexVars(inlined_store_.As()->indices, n_vars)) { + return false; + } + return true; +} + +void ComputeInliner::Visit(const ir::Load* expr, Expr* op) { + if ((expr->tensor).as_tensor_ref()->name == inlined_tensor_->name) { + *op = ReplaceInlinedTensor(op); + return; + } + IRMutator::Visit(expr, op); +} + +//! Replace the 'Load' node on the tensor to 'Load' node of its producers. +Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) { + CHECK(load->As()); + SetIndexSubstitution(load->As()->indices); + Expr value_copy = optim::IRCopy(inlined_store_.As()->value); + ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_); + return value_copy; +} + +void ScheduleImpl::ComputeInline(const Expr& schedule_block) { + CHECK(schedule_block.As()); + Expr root = this->GetRootBlock(schedule_block); + Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root); + ComputeInliner inliner(store.As()->tensor.as_tensor_ref(), store); + CHECK(inliner.BodyPatternAllowInline()); + // Create a plan that removes the block to be inlined + LeafBlockRemovalPlan remove_plan(schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); + remove_plan(&root); + inliner(&root); + return; +} + +bool ComputeInlineChecker::Check() { + Expr root = ir_schedule_.GetRootBlock(block_); + store_ = CheckComputeInlineValidationAndGetStore(block_, root); + IRMutator::Visit(&root, &root); + return !should_skip_; +} + +void ComputeInlineChecker::BuildDataDependency() { + ir_schedule_.SetBuffer(block_, "shared", true); + auto loops = ir_schedule_.GetLoops(block_); + ir_schedule_.SyncThreads(loops.back(), true); +} + +bool ReverseComputeInliner::BodyPatternAllowInline() { + if (!inlined_store_.defined()) { + return false; + } + if (!inlined_load_.defined()) { + return false; + } + if (!target_store_.defined()) { + return false; + } + CHECK(inlined_store_.As()); + CHECK(inlined_load_.As()); + CHECK(target_store_.As()); + auto find_vars = ir::CollectIRNodesWithoutTensor(inlined_store_, [&](const Expr* x) { return x->as_var(); }); + std::set vars_set; + for (auto& i : find_vars) vars_set.insert(i.as_var_ref()); + int n_vars = vars_set.size(); + if (!UpdateAndCheckIndexVars(inlined_store_.As()->indices, n_vars)) { + return false; + } + return true; +} + +void ReverseComputeInliner::Visit(const ir::Load* expr, Expr* op) { + if ((expr->tensor).as_tensor_ref()->name == inlined_tensor_->name) { + *op = inlined_store_.As()->value; + return; + } + IRMutator::Visit(expr, op); +} + +void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) { + if ((expr->tensor).as_tensor_ref()->name == 
inlined_tensor_->name) { + *op = ReplaceTargetTensor(op); + return; + } + IRMutator::Visit(expr, op); +} + +//! Replace the 'Load' node on the tensor to 'Load' node of its producers. +Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) { + CHECK(load->As()); + SetIndexSubstitution(load->As()->indices); + Expr value_copy = optim::IRCopy(inlined_store_.As()->value); + return value_copy; +} + +Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) { + auto indices = inlined_load_.As()->indices; + CHECK_EQ(indices.size(), idx_vars_.size()); + size_t n = idx_vars_.size(); + idx_sub_var_.reserve(n); + idx_sub_expr_.reserve(n); + for (int i = 0; i < n; ++i) { + idx_sub_var_.emplace_back(indices[i].as_var_ref()); + idx_sub_expr_.emplace_back(idx_vars_[i]); + } + + Expr value_copy = optim::IRCopy(target_store_); + ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_); + return value_copy; +} + +void ScheduleImpl::ReverseComputeInline(const Expr& schedule_block) { + Expr root = this->GetRootBlock(schedule_block); + auto exprs = CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root); + Expr inlined_load = std::get<0>(exprs); + Expr inlined_store = std::get<1>(exprs); + Expr target_store = std::get<2>(exprs); + ReverseComputeInliner inliner( + inlined_store.As()->tensor.as_tensor_ref(), inlined_store, inlined_load, target_store); + CHECK(inliner.BodyPatternAllowInline()); + // Create a plan that removes the block to be inlined + LeafBlockRemovalPlan remove_plan(schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); + remove_plan(&root); + inliner(&root); + inliner(&root); +} + +struct FindBlockParent : public ir::IRMutator<> { + public: + FindBlockParent(const std::string& block_name) : block_name_(block_name) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::Block* expr, Expr* op) override { + if (target_) return; + for (auto& stmt : expr->stmts) { + if (stmt.As()) { + if (stmt.As()->schedule_block.As()->name == block_name_) { + target_ = op; + return; + } + } + } + IRMutator::Visit(expr, op); + } + + void Visit(const ir::For* expr, Expr* op) override { + if (target_) return; + if (expr->body.As()) { + if (expr->body.As()->schedule_block.As()->name == block_name_) { + target_ = op; + return; + } + } + IRMutator::Visit(expr, op); + } + + void Visit(const ir::ScheduleBlock* expr, Expr* op) override { + if (target_) return; + if (expr->body.As()) { + if (expr->body.As()->schedule_block.As()->name == block_name_) { + target_ = op; + return; + } + } + IRMutator::Visit(expr, op); + } + + std::string block_name_; + + public: + ir::Expr* target_{nullptr}; +}; + +Expr ScheduleImpl::AddUnitLoop(const Expr& block) const { + auto exprs = module_expr_.GetExprs(); + CHECK(block.As()); + CHECK(block.As()->schedule_block.As()); + std::string block_name = block.As()->schedule_block.As()->name; + + FindBlockParent visitor(block_name); + for (auto expr : exprs) { + visitor(&expr); + if (visitor.target_) { + break; + } + } + + CHECK(visitor.target_) << ", block name : " << block_name << "\n" << exprs; + if (visitor.target_->As()) { + for (auto& stmt : visitor.target_->As()->stmts) { + if (stmt.As()) { + if (stmt.As()->schedule_block.As()->name == block_name) { + auto block = ir::Block::Make({GetBlock(block_name)}); + auto loop = ir::For::Make(ir::Var(common::UniqName("ix")), + ir::Expr(0), + ir::Expr(1), + ir::ForType::Serial, + ir::DeviceAPI::UNK, + block); + stmt = loop; + return loop; + } + } + } + } else if (visitor.target_->As()) { 
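+    // Schematic: the parent For's body `b` becomes `for (ix, 0, 1) { b }`, so the
+    // target block always sits under at least one loop.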
+ auto block = ir::Block::Make({visitor.target_->As()->body}); + auto loop = ir::For::Make( + ir::Var(common::UniqName("ix")), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, block); + visitor.target_->As()->body = loop; + return loop; + } else if (visitor.target_->As()) { + auto block = ir::Block::Make({visitor.target_->As()->body}); + auto loop = ir::For::Make( + ir::Var(common::UniqName("ix")), ir::Expr(0), ir::Expr(1), ir::ForType::Serial, ir::DeviceAPI::UNK, block); + visitor.target_->As()->body = loop; + return loop; + } else { + LOG(FATAL) << "Can't find block's parent!"; + } + LOG(FATAL) << "Shouldn't reach code here in AddUnitLoop"; + return Expr{nullptr}; +} + +std::vector ScheduleImpl::GetLoops(const Expr& block) const { + std::vector result; + auto exprs = module_expr_.GetExprs(); + CHECK(block.As()); + CHECK(block.As()->schedule_block.As()); + std::string block_name = block.As()->schedule_block.As()->name; + + for (auto& it_expr : exprs) { + ir::FindLoopsVisitor visitor(block); + auto find_loops = visitor(&it_expr); + if (!find_loops.empty()) { + if (!result.empty()) LOG(FATAL) << "Find block with name: \n" << block_name << " appeared in more than one AST!"; + result = find_loops; + } + } + + if (result.empty()) { + result.push_back(AddUnitLoop(block)); + } + return result; +} + +std::vector ScheduleImpl::GetLoops(const std::string& block_name) const { + Expr block = this->GetBlock(block_name); + std::vector result = this->GetLoops(block); + return result; +} + +std::vector ScheduleImpl::GetAllBlocks() const { + std::vector result; + auto exprs = module_expr_.GetExprs(); + for (auto& it_expr : exprs) { + ir::FindBlocksVisitor visitor; + auto find_blocks = visitor(&it_expr); + result.insert(result.end(), find_blocks.begin(), find_blocks.end()); + } + for (auto& it_expr : exprs) { + VLOG(3) << "it_expr is : " << it_expr; + } + CHECK(!result.empty()) << "Didn't find blocks in expr."; + return result; +} + +std::vector ScheduleImpl::GetChildBlocks(const Expr& expr) const { + CHECK(expr.As() || expr.As()); + ir::FindBlocksVisitor visitor; + std::vector result = visitor(&expr); + return result; +} + +bool ScheduleImpl::HasBlock(const std::string& block_name) const { + auto exprs = module_expr_.GetExprs(); + for (auto& it_expr : exprs) { + ir::FindBlocksVisitor visitor(block_name); + auto find_blocks = visitor(&it_expr); + if (!find_blocks.empty()) { + CHECK_EQ(find_blocks.size(), 1U) << "There should not be more than 1 block with identical name!"; + return true; + } + } + return false; +} + +Expr ScheduleImpl::GetBlock(const std::string& block_name) const { + Expr result; + auto exprs = module_expr_.GetExprs(); + for (auto& it_expr : exprs) { + ir::FindBlocksVisitor visitor(block_name); + auto find_blocks = visitor(&it_expr); + if (!find_blocks.empty()) { + CHECK_EQ(find_blocks.size(), 1U) << "There should not be more than 1 block with identical name!"; + result = find_blocks[0]; + return result; + } + } + LOG(FATAL) << "Didn't find a block with name " << block_name << " in this ModuleExpr!"; +} + +void ScheduleImpl::Annotate(const Expr& block, const std::string& key, const attr_t& value) { + CHECK(block.As()); + CHECK(block.As()->schedule_block.As()); + auto copied_block = optim::IRCopy(block); + auto* schedule_block = copied_block.As()->schedule_block.As(); + schedule_block->attrs.emplace(key, value); + this->Replace(block, copied_block); +} + +void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) { + CHECK(block.As()); + 
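+  // e.g. Unannotate(block, "my_key") removes an attribute previously added by
+  // Annotate(block, "my_key", ...); a missing key only logs a warning below.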
CHECK(block.As()->schedule_block.As()); + auto* schedule_block = block.As()->schedule_block.As(); + if (schedule_block->attrs.count(ann_key)) { + schedule_block->attrs.erase(ann_key); + } else { + LOG(WARNING) << "Can't find annotation with key: " << ann_key; + return; + } +} + +void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_tensor) { + CHECK_GT(loops.size(), 0) << "Loops can't be empty!"; + VLOG(4) << "Before FlattenLoops, ir is:\n" << loops[0]; + // compute loop + int extent = 1; + std::vector strides; + std::vector loop_vars(loops.size()); + for (int idx = loops.size() - 1; idx >= 0; --idx) { + strides.insert(strides.begin(), extent); + extent *= loops[idx].As()->extent.as_int32(); + loop_vars[idx] = loops[idx].As()->loop_var; + } + CHECK_EQ(loops.size(), strides.size()); + + // create new loop. + auto last = loops.back().As(); + auto var = ir::Var("flat_i"); + auto _var = ir::Var("_flat_i"); + auto loop = ir::For::Make(var, ir::Expr(0), ir::Expr(extent), last->for_type(), last->device_api, last->body); + + // map loop var to old loop var. + auto _iter = ir::Expr(_var); + std::unordered_map loops_to_flat_var_map; + for (int idx = 0; idx < strides.size(); ++idx) { + if (strides[idx] == 1) { + // flat_i_to_loop_var.push_back(_iter); + loops_to_flat_var_map[loops[idx].As()->loop_var->name] = _iter; + } else { + // flat_i_to_loop_var.push_back(_iter / Expr(strides[idx])); + loops_to_flat_var_map[loops[idx].As()->loop_var->name] = _iter / Expr(strides[idx]); + _iter = _iter % Expr(strides[idx]); + } + } + + ir::FindBlocksVisitor visitor; + auto blocks = visitor(&last->body); + auto can_do_flat = [](const std::vector& indexs, const std::vector& loop_vars) { + if (indexs.size() != loop_vars.size()) { + return false; + } + + for (int idx = 0; idx < indexs.size(); ++idx) { + if (!indexs[idx].as_var()) { + return false; + } else { + auto var = indexs[idx].as_var_ref(); + if (var->name != loop_vars[idx]->name) { + return false; + } + } + } + return true; + }; + + // change blocks iter value/iter var + for (auto& block : blocks) { + auto block_realize = block.As(); + auto schedule_block = block_realize->schedule_block.As(); + + // checkout loops in orders. + std::vector var_names = {}; + CHECK_GE(block_realize->iter_values.size(), loop_vars.size()) + << "the number of iter bind values must be more than loop vars!"; + for (int idx = 0; idx < block_realize->iter_values.size(); ++idx) { + auto& iter = block_realize->iter_values[idx]; + if (iter.is_var()) { + CHECK_EQ(iter.as_var_ref()->name, loop_vars[idx]->name) << "loops is not the same order with tensor!"; + } else { + CHECK(iter.As()); + CHECK_EQ(iter.as_int32(), 0); + } + } + + auto exprs = ir::CollectIRNodesInOrder(schedule_block->body, + [&](const Expr* x) { return x->As() || x->As(); }); + // reverse exprs from last to first. + std::reverse(std::begin(exprs), std::end(exprs)); + + std::vector var_to_replace; + std::vector flat_i_to_loop_var; + // if iter var is more than flat i to loop, there exist dim = 1. + for (int idx = 0; idx < block_realize->iter_values.size(); ++idx) { + if (block_realize->iter_values[idx].is_var()) { + var_to_replace.push_back(schedule_block->iter_vars[idx]); + auto var_name = block_realize->iter_values[idx].as_var_ref()->name; + CHECK(loops_to_flat_var_map.count(var_name)) << "Can't find var name : " << var_name; + flat_i_to_loop_var.push_back(loops_to_flat_var_map[var_name]); + } else { + CHECK_EQ(block_realize->iter_values[idx].as_int32(), 0); + // insert var -> 0, to replace var to 0. 
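+        // Worked example: loop extents (4, 5, 6) give strides (30, 6, 1) and a flat
+        // extent of 120; the loop vars map to _flat_i / 30, (_flat_i % 30) / 6 and
+        // _flat_i % 6, while constant-0 iter values are replaced by Expr(0).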
+ var_to_replace.push_back(schedule_block->iter_vars[idx]); + flat_i_to_loop_var.push_back(Expr(0)); + } + } + CHECK_EQ(var_to_replace.size(), flat_i_to_loop_var.size()); + + for (auto expr : exprs) { + if (expr.As()) { + auto store = expr.As(); + if (store->is_addr_tensor()) { + auto t = store->tensor.as_tensor_ref(); + CHECK(!t->reduce_axis.size()); + auto tsize = std::accumulate(t->shape.begin(), t->shape.end(), 1, [](const int sum, const Expr& expr) { + return sum * expr.as_int32(); + }); + if ((!flat_tensor && !can_do_flat(store->indices, schedule_block->iter_vars)) || extent != tsize) { + // just replace indexs + for (auto& indice : store->indices) { + if (!indice.is_var()) { + continue; + } + ReplaceExpr(&indice, var_to_replace, flat_i_to_loop_var); + } + // compute index and flat tensor. + store->indices = {store->index()}; + continue; + } + // update var and shape + store->indices = {Expr(_var)}; + } + } else { + auto load = expr.As(); + if (load->is_addr_tensor()) { + auto t = load->tensor.as_tensor_ref(); + CHECK(!t->reduce_axis.size()); + auto tsize = std::accumulate(t->shape.begin(), t->shape.end(), 1, [](const int sum, const Expr& expr) { + return sum * expr.as_int32(); + }); + if ((!flat_tensor && !can_do_flat(load->indices, schedule_block->iter_vars)) || extent != tsize) { + // just replace indexs + for (auto& indice : load->indices) { + if (!indice.is_var()) { + continue; + } + ReplaceExpr(&indice, var_to_replace, flat_i_to_loop_var); + } + // compute index and flat tensor. + load->indices = {load->index()}; + continue; + } + // update var and shape + load->indices = {Expr(_var)}; + } + } + } + ReplaceExpr(&schedule_block->body, var_to_replace, flat_i_to_loop_var); + + // update iter values + auto iter = ir::Expr(var); + block_realize->iter_values = {iter}; + + // update iter_vars + schedule_block->iter_vars = {_var}; + CHECK_EQ(block_realize->iter_values.size(), schedule_block->iter_vars.size()); + } + + this->Replace(loops[0], loop); + VLOG(4) << "After FlattenLoops, ir is:\n" << loop; +} + +void ScheduleImpl::CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name) { + auto block = this->GetBlock(block_name); + auto block_target = this->GetBlock(block_target_name); + this->CopyTransformAndLoopInfo(block, block_target); +} + +void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target) { + CHECK(block.As()); + CHECK(block_target.As()); + auto exprs = this->GetModule().GetExprs(); + CHECK_EQ(exprs.size(), 1U); + auto expr = exprs[0]; + auto vars = block.As()->schedule_block.As()->iter_vars; + auto vars_target = block_target.As()->schedule_block.As()->iter_vars; + auto old_iter_values = block.As()->iter_values; + auto iter_values_target = block_target.As()->iter_values; + std::vector new_iter_values; + for (int i = 0; i < vars.size() && i < vars_target.size(); ++i) { + CHECK(vars[i]->upper_bound.defined() && vars_target[i]->upper_bound.defined()); + if (vars[i]->upper_bound.is_constant() && vars_target[i]->upper_bound.is_constant() && + vars[i]->upper_bound.get_constant() == vars_target[i]->upper_bound.get_constant() && !vars[i]->is_reduce_axis && + !vars_target[i]->is_reduce_axis) { + new_iter_values.push_back(iter_values_target[i]); + VLOG(3) << "new_iter_values.push_back " << iter_values_target[i]; + } else + break; + } + + if (new_iter_values.empty()) + LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source and target is not equal! 
" + << vars[0]->upper_bound << " v.s " << vars_target[0]->upper_bound; + + int changed_loop_num = new_iter_values.size(); + std::set used_target_loop_vars; + for (auto& iter_val : new_iter_values) { + auto find_partial_loop = ir::CollectIRNodesWithoutTensor(iter_val, [&](const Expr* x) { + if (x->as_var()) used_target_loop_vars.insert(x->as_var_ref()->name); + return x->as_var(); + }); + } + CHECK(!used_target_loop_vars.empty()); + std::vector used_target_loops; + auto expr_copy = optim::IRCopy(expr); + for (auto& var : used_target_loop_vars) { + auto find_loop_var = ir::CollectIRNodesWithoutTensor( + expr_copy, + [&](const Expr* x) { + return x->As() && x->As()->loop_var->name == var && Contains(*x, block_target); + }, + true); + CHECK_EQ(find_loop_var.size(), 1U); + used_target_loops.push_back(*find_loop_var.begin()); + VLOG(3) << "used_target_loops push_back " << used_target_loops.back(); + } + std::sort(used_target_loops.begin(), used_target_loops.end(), [&](Expr i, Expr j) { + return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); + }); + for (int i = new_iter_values.size(); i < old_iter_values.size(); ++i) { + CHECK(old_iter_values[i].as_var()); + new_iter_values.push_back(old_iter_values[i]); + } + Expr new_loop; + VLOG(3) << "changed_loop_num is : " << changed_loop_num; + VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size(); + if (changed_loop_num >= (int)old_iter_values.size()) { + new_loop = optim::IRCopy(block); + new_loop.As()->iter_values = new_iter_values; + } else { + CHECK(old_iter_values[changed_loop_num].as_var()); + auto old_var = old_iter_values[changed_loop_num].as_var_ref(); + auto find_partial_loop = ir::CollectIRNodesWithoutTensor( + expr, + [&](const Expr* x) { + return x->As() && x->As()->loop_var->name == old_var->name && Contains(*x, block); + }, + true); + CHECK_EQ(find_partial_loop.size(), 1U); + new_loop = optim::IRCopy(*find_partial_loop.begin()); + auto find_schedule_block = ir::CollectIRNodesWithoutTensor( + new_loop, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_schedule_block.size(), 1U); + Expr sch_block = (*find_schedule_block.begin()); + sch_block.As()->iter_values = new_iter_values; + } + VLOG(3) << "new_loop is : " << new_loop; + CHECK(!used_target_loops.empty()); + Expr res; + if (used_target_loops.size() == 1) { + auto for_loop = used_target_loops[0].As(); + res = For::Make(for_loop->loop_var, + for_loop->min, + for_loop->extent, + for_loop->for_type(), + for_loop->device_api, + new_loop, + for_loop->vectorize_info(), + for_loop->bind_info()); + } else { + Expr outer_loop = used_target_loops.front(); + Expr inner_loop = used_target_loops.back(); + inner_loop.As()->body = Block::Make({new_loop}); + res = outer_loop; + } + VLOG(3) << "res is : " << res; + std::vector all_loops = this->GetLoops(block); + CHECK(!all_loops.empty()); + this->Replace(all_loops[0], res); +} + +std::vector ScheduleImpl::SamplePerfectTile(utils::LinearRandomEngine::StateType* rand_seed, + const Expr& loop, + int n, + int max_innermost_factor) { + CHECK(loop.As()) << "Expr param of SamplePerfectTile should be a For loop"; + CHECK_GE(n, 2) << "The number of tile factors should be at least 2"; + CHECK_GE(max_innermost_factor, 1) << "The max innermost factor should be at least 1"; + CHECK(common::is_zero(loop.As()->min)) << "The For loop should start from 0"; + int loop_extent = GetLoopExtent(loop); + std::vector innermost_factors; + for (int i = max_innermost_factor; i >= 1; --i) { + if (loop_extent % i == 0) { + 
innermost_factors.push_back(i); + } + } + CHECK(!innermost_factors.empty()) << "No innermost factor found"; + int innermost_factor = innermost_factors[utils::SampleUniformInt(0, innermost_factors.size(), rand_seed)]; + auto result = SampleTile(rand_seed, n - 1, loop_extent / innermost_factor); + std::vector result_expr; + for (auto& factor : result) { + result_expr.push_back(Expr(factor)); + } + result_expr.push_back(Expr(innermost_factor)); + return result_expr; +} + +Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_seed, + const std::vector& candidates, + const std::vector& probs) { + // check two sizes + CHECK_EQ(candidates.size(), probs.size()) << "candidates and probs must have same size."; + int seed_idx = utils::SampleDiscreteFromDistribution(probs, rand_seed); + auto result = candidates[seed_idx]; + Expr result_expr(result); + return result_expr; +} + +IRSchedule::IRSchedule() {} + +IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) { + impl_ = std::make_unique(module_expr, debug_flag); + this->InitSeed(rand_seed); +} + +IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed) + : impl_(std::make_unique(std::move(mod_expr))), trace_(std::move(trace)) { + this->InitSeed(rand_seed); +} + +IRSchedule::IRSchedule(const IRSchedule& other) + : impl_(std::make_unique(optim::IRCopy(other.GetModule()))), trace_(other.trace_) { + this->InitSeed(other.ForkSeed()); +} + +IRSchedule& IRSchedule::operator=(const IRSchedule& src) { + impl_ = std::make_unique(optim::IRCopy(src.GetModule())); + trace_ = src.trace_; + this->InitSeed(src.ForkSeed()); + return *this; +} + +IRSchedule::IRSchedule(IRSchedule&& other) : impl_(std::move(other.impl_)), trace_(std::move(other.trace_)) { + this->InitSeed(other.ForkSeed()); +} + +IRSchedule& IRSchedule::operator=(IRSchedule&& src) { + impl_ = std::move(src.impl_); + trace_ = std::move(src.trace_); + this->InitSeed(src.ForkSeed()); + return *this; +} + +IRSchedule::~IRSchedule() {} + +void IRSchedule::InitSeed(utils::LinearRandomEngine::StateType rand_seed) { + this->rand_seed_ = utils::LinearRandomEngine::NormalizeState(rand_seed); +} + +utils::LinearRandomEngine::StateType IRSchedule::ForkSeed() const { return utils::ForkRandomState(&rand_seed_); } + +void IRSchedule::SetExprs(const std::vector& exprs) { + return impl_->SetExprs(exprs); + // no need to trace +} + +const ModuleExpr& IRSchedule::GetModule() const { + return impl_->GetModule(); + // no need to trace +} + +bool IRSchedule::HasBlock(const std::string& block_name) const { + return impl_->HasBlock(block_name); + // no need to trace +} + +void IRSchedule::MergeExprs() { + impl_->MergeExprs(); + trace_.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {})); +} + +std::vector IRSchedule::GetLoops(const Expr& block) const { + auto results = impl_->GetLoops(block); + trace_.Append(ScheduleDesc::Step("GetLoops", {{"block", std::vector({block})}}, {}, results)); + return results; +} + +std::vector IRSchedule::GetLoops(const std::string& block_name) const { + auto results = impl_->GetLoops(block_name); + trace_.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", block_name}}, results)); + return results; +} + +std::vector IRSchedule::GetAllBlocks() const { + auto results = impl_->GetAllBlocks(); + trace_.Append(ScheduleDesc::Step("GetAllBlocks", {}, {}, results)); + return results; +} + +std::vector IRSchedule::GetChildBlocks(const 
Expr& expr) const { + auto results = impl_->GetChildBlocks(expr); + trace_.Append(ScheduleDesc::Step("GetChildBlocks", {{"expr", std::vector({expr})}}, {}, results)); + return results; +} + +Expr IRSchedule::GetBlock(const std::string& block_name) const { + auto result = impl_->GetBlock(block_name); + trace_.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", block_name}}, {result})); + return result; +} + +std::vector IRSchedule::Split(const Expr& loop, const std::vector& factors) { + std::vector decision = SamplePerfectTile(loop, factors.size(), loop.As()->extent.as_int32(), factors); + auto results = Split(loop, decision); + return results; +} + +std::vector IRSchedule::Split(const std::string& block_name, int loop_index, const std::vector& factors) { + std::vector all_loops = this->GetLoops(block_name); + Expr loop_expr; + CHECK_LT(loop_index, (int)all_loops.size()) << "The loop index in Split should be less than total loop's number."; + CHECK_GE(loop_index, 0) << "The loop index in Split should be >= 0."; + loop_expr = all_loops[loop_index]; + + return this->Split(loop_expr, factors); +} + +std::vector IRSchedule::Split(const Expr& loop, const std::vector& factors) { + std::vector int_factors; + std::transform(factors.begin(), factors.end(), std::back_inserter(int_factors), [](Expr x) { return x.as_int32(); }); + auto results = impl_->Split(loop, int_factors); + trace_.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({loop})}, {"factors", factors}}, {}, results)); + return results; +} + +Expr IRSchedule::Fuse(const std::vector& loops) { + auto result = impl_->Fuse(loops); + trace_.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {result})); + return result; +} + +Expr IRSchedule::Fuse(const std::string& block_name, const std::vector& loops_index) { + auto result = impl_->Fuse(block_name, loops_index); + trace_.Append( + ScheduleDesc::Step("FuseWithName", {}, {{"block_name", block_name}, {"loops_index", loops_index}}, {result})); + return result; +} + +Expr IRSchedule::Fuse(const Expr& block, const std::vector& loops_index) { + auto result = impl_->Fuse(block, loops_index); + trace_.Append(ScheduleDesc::Step( + "FuseWithBlock", {{"block", std::vector({block})}}, {{"loops_index", loops_index}}, {result})); + return result; +} + +void IRSchedule::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { + impl_->ComputeAt(block, loop, keep_unit_loops); + trace_.Append(ScheduleDesc::Step("ComputeAt", + {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, + {{"keep_unit_loops", keep_unit_loops}}, + {})); +} + +void IRSchedule::SimpleComputeAt(const Expr& block, const Expr& loop) { + impl_->SimpleComputeAt(block, loop); + trace_.Append(ScheduleDesc::Step( + "SimpleComputeAt", {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, {}, {})); +} + +void IRSchedule::ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { + impl_->ReverseComputeAt(block, loop, keep_unit_loops); + trace_.Append(ScheduleDesc::Step("ReverseComputeAt", + {{"block", std::vector({block})}, {"loop", std::vector({loop})}}, + {{"keep_unit_loops", keep_unit_loops}}, + {})); +} + +Expr IRSchedule::GetRootBlock(const Expr& expr) const { + auto result = impl_->GetRootBlock(expr); + trace_.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({expr})}}, {}, {result})); + return result; +} + +Expr IRSchedule::CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type) { + auto result = impl_->CacheRead(block, 
read_buffer_index, memory_type); + trace_.Append(ScheduleDesc::Step("CacheRead", + {{"block", std::vector({block})}}, + {{"read_buffer_index", read_buffer_index}, {"memory_type", memory_type}}, + {result})); + return result; +} + +Expr IRSchedule::CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type) { + auto result = impl_->CacheWrite(block, write_buffer_index, memory_type); + trace_.Append(ScheduleDesc::Step("CacheWrite", + {{"block", std::vector({block})}}, + {{"write_buffer_index", write_buffer_index}, {"memory_type", memory_type}}, + {result})); + return result; +} + +void IRSchedule::SyncThreads(const Expr& ir_node, bool after_node) { + impl_->SyncThreads(ir_node, after_node); + trace_.Append( + ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector({ir_node})}}, {{"after_node", after_node}}, {})); +} + +void IRSchedule::SetBuffer(Expr& block, const std::string& memory_type, bool fixed) { + impl_->SetBuffer(block, memory_type, fixed); + trace_.Append(ScheduleDesc::Step( + "SetBuffer", {{"block", std::vector({block})}}, {{"memory_type", memory_type}, {"fixed", fixed}}, {})); +} + +Expr IRSchedule::Reorder(const std::vector& loops) { + Expr ret = impl_->Reorder(loops); + trace_.Append(ScheduleDesc::Step("Reorder", {{"loops", loops}}, {}, {ret})); + return ret; +} + +Expr IRSchedule::Reorder(const std::string& block_name, const std::vector& loops_index) { + Expr ret = impl_->Reorder(block_name, loops_index); + trace_.Append( + ScheduleDesc::Step("ReorderWithName", {}, {{"block_name", block_name}, {"loops_index", loops_index}}, {ret})); + return ret; +} + +Expr IRSchedule::Reorder(const Expr& block, const std::vector& loops_index) { + Expr ret = impl_->Reorder(block, loops_index); + trace_.Append(ScheduleDesc::Step( + "ReorderWithBlock", {{"block", std::vector({block})}}, {{"loops_index", loops_index}}, {ret})); + return ret; +} + +void IRSchedule::Parallel(const Expr& loop) { + impl_->Parallel(loop); + trace_.Append(ScheduleDesc::Step("Parallel", {{"loop", std::vector({loop})}}, {}, {})); +} + +void IRSchedule::Vectorize(const Expr& loop, int factor) { + impl_->Vectorize(loop, factor); + trace_.Append(ScheduleDesc::Step("Vectorize", {{"loop", std::vector({loop})}}, {{"factor", factor}}, {})); +} + +void IRSchedule::Unroll(const Expr& loop) { + impl_->Unroll(loop); + trace_.Append(ScheduleDesc::Step("Unroll", {{"loop", std::vector({loop})}}, {}, {})); +} + +void IRSchedule::ComputeInline(const Expr& schedule_block) { + impl_->ComputeInline(schedule_block); + trace_.Append(ScheduleDesc::Step("ComputeInline", {{"schedule_block", std::vector({schedule_block})}}, {}, {})); +} + +void IRSchedule::ReverseComputeInline(const Expr& schedule_block) { + impl_->ReverseComputeInline(schedule_block); + trace_.Append( + ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector({schedule_block})}}, {}, {})); +} + +void IRSchedule::Bind(const Expr& loop, const std::string& thread_axis) { + impl_->Bind(loop, thread_axis); + trace_.Append(ScheduleDesc::Step("Bind", {{"loop", std::vector({loop})}}, {{"thread_axis", thread_axis}}, {})); +} + +Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) { + auto result = impl_->Rfactor(rf_loop, rf_axis); + trace_.Append( + ScheduleDesc::Step("Rfactor", {{"rf_loop", std::vector({rf_loop})}}, {{"rf_axis", rf_axis}}, {result})); + return result; +} + +void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_t& value) { + impl_->Annotate(block, key, value); + +#define 
TRACE_ANNOTATE_ITEM(data_type, step_name) \ + if (absl::holds_alternative(value)) { \ + trace_.Append(ScheduleDesc::Step(#step_name, \ + {{"block", std::vector({block})}}, \ + {{"key", key}, {"value", absl::get(value)}}, \ + {})); \ + return; \ + } + TRACE_ANNOTATE_ITEM(int, AnnotateIntAttr) + TRACE_ANNOTATE_ITEM(bool, AnnotateBoolAttr) + TRACE_ANNOTATE_ITEM(float, AnnotateFloatAttr) + TRACE_ANNOTATE_ITEM(std::string, AnnotateStringAttr) +#undef TRACE_ANNOTATE_ITEM + + LOG(FATAL) << "Value of attribute:" << key << " input unsupported data type"; +} + +void IRSchedule::Unannotate(Expr& block, const std::string& key) { + impl_->Unannotate(block, key); + trace_.Append(ScheduleDesc::Step("Unannotate", {{"block", std::vector({block})}}, {{"key", key}}, {})); +} + +void IRSchedule::FlattenLoops(const std::vector& loops, const bool force_flat) { + impl_->FlattenLoops(loops, force_flat); + trace_.Append( + ScheduleDesc::Step("FlattenLoops", {{"loop", std::vector({loops})}}, {{"force_flat", force_flat}}, {})); +} + +void IRSchedule::CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target) { + impl_->CopyTransformAndLoopInfo(block, block_target); + // don't support to trace, because we can't ensure both blocks are from the same ModuleExpr +} + +void IRSchedule::CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name) { + impl_->CopyTransformAndLoopInfo(block_name, block_target_name); + // don't support to trace, because we can't ensure both blocks are from the same ModuleExpr +} + +std::vector IRSchedule::SamplePerfectTile(const Expr& loop, + int n, + int max_innermost_factor, + const std::vector& decision) { + std::vector factors; + std::vector new_decision; + if (decision.empty()) { + factors = impl_->SamplePerfectTile(&rand_seed_, loop, n, max_innermost_factor); + std::transform( + factors.begin(), factors.end(), std::back_inserter(new_decision), [](Expr x) { return x.as_int32(); }); + } else { + new_decision = decision; + std::transform(decision.begin(), decision.end(), std::back_inserter(factors), [](int x) { return Expr(x); }); + } + trace_.Append( + ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loop})}}, + {{"n", n}, {"max_innermost_factor", max_innermost_factor}, {"decision", new_decision}}, + factors)); + return factors; +} + +void IRSchedule::TagPostSchedule() { trace_.Append(ScheduleDesc::Step("TagPostSchedule", {}, {}, {})); } + +Expr IRSchedule::SampleCategorical(const std::vector& candidates, + const std::vector& probs, + const std::vector& decision) { + Expr result; + std::vector new_decision; + if (decision.empty()) { + result = impl_->SampleCategorical(&rand_seed_, candidates, probs); + new_decision.push_back(result.as_int32()); + } else { + new_decision = decision; + for (auto ndco : new_decision) { + result = Expr(ndco); + } + } + trace_.Append(ScheduleDesc::Step( + "SampleCategorical", {}, {{"candidates", candidates}, {"probs", probs}, {"decision", new_decision}}, {result})); + return result; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_schedule.h b/paddle/cinn/ir/ir_schedule.h new file mode 100644 index 0000000000000..6b7b252a57dec --- /dev/null +++ b/paddle/cinn/ir/ir_schedule.h @@ -0,0 +1,614 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/schedule_desc.h" +#include "cinn/ir/tensor.h" +#include "cinn/utils/random_engine.h" + +namespace cinn { +namespace ir { + +/** + * A struct representing a module that contains Expr. This struct is only used in Schedule process. + */ +class ModuleExpr { + public: + ModuleExpr() = default; + ModuleExpr(const ModuleExpr& mod_expr) = default; + ModuleExpr(ModuleExpr&& mod_expr) = default; + + ModuleExpr& operator=(const ModuleExpr& mod_expr) = default; + + explicit ModuleExpr(const std::vector& exprs) : exprs_(exprs) {} + explicit ModuleExpr(std::vector&& exprs) : exprs_(std::move(exprs)) {} + + //! Get all the Expr in this ModuleExpr. + std::vector GetExprs() { return exprs_; } + + std::vector GetExprs() const { return exprs_; } + + void SetExprs(const std::vector& exprs) { exprs_ = exprs; } + + private: + //! Exprs stored in ModuleExpr. Each one is an AST, representing a computation kernel. + std::vector exprs_; +}; + +/** + * A struct containing all the schedule primitives. Each shedule primitive is a member function of IRSchedule. + * Schedule primitves are implmented by ScheduleImpl manipulating the AST - IR(Expr). + * To support serializing and replaying, each schedule primitive should append a ScheduleDesc::Step to + * the trace_ in its corresponding function implment. + */ +class ScheduleImpl; +class IRSchedule { + public: + IRSchedule(); + explicit IRSchedule(const ModuleExpr& modexpr, + utils::LinearRandomEngine::StateType rand_seed = -1, + bool debug_flag = false); + IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1); + IRSchedule(const IRSchedule& other); + IRSchedule& operator=(const IRSchedule& src); + IRSchedule(IRSchedule&& other); + IRSchedule& operator=(IRSchedule&& src); + ~IRSchedule(); + + void SetExprs(const std::vector& exprs); + + //! Get the ModuleExpr stored in ScheduleImpl. + const ModuleExpr& GetModule() const; + + //! Determine whether a specific block is included + bool HasBlock(const std::string& block_name) const; + + //! Merge multiple Exprs in a ModuleExpr to be one + void MergeExprs(); + + //! Get the ScheduleDesc that traces the scheduling process + const ScheduleDesc& GetTraceDesc() const { return trace_; } + + /** + * \brief Get all the loops of specific Block stored in ModuleExpr. + * @param block The block we find loop in. + * @return Loops of the block. + */ + std::vector GetLoops(const Expr& block) const; + + /** + * \brief Get all the loops of specific Block stored in ModuleExpr. + * @param block_name Name of the block. + * @return Loops of the block. + */ + std::vector GetLoops(const std::string& block_name) const; + + //! Get all blocks stored in this ModuleExpr. + std::vector GetAllBlocks() const; + + //! Get a block with the specific name. + Expr GetBlock(const std::string& block_name) const; + + /** + * \brief Get all the childblocks of specific Expr stored in ModuleExpr. 
+ * @param expr The expr we find childblock in, can be a loop or block. + * @return ChildBlocks of the block. + */ + std::vector GetChildBlocks(const Expr& expr) const; + + /** + * \brief Split a for loop into multiple loops, based on the factors. + * @param loop The loop to be splited. + * @param factors The factors we used to split the loop. + * @return The splited loops. + */ + std::vector Split(const Expr& loop, const std::vector& factors); + + /** + * \brief Split a for loop into multiple loops, based on the factors. + * @param block_name Name of the block we want to modify. + * @param loop_index Index of the loop to be splited. + * @param factors The factors we used to split the loop. + * @return The splited loops. + */ + std::vector Split(const std::string& block_name, int loop_index, const std::vector& factors); + + /** + * \brief Split a for loop into multiple loops, based on the factors, only used for deserialization of trace. + * @param loop The loop to be splited. + * @param factors The factors we used to split the loop. + * @return The splited loops. + */ + std::vector Split(const Expr& loop, const std::vector& factors); + + /** + * \brief Fuse for loops and return the fused loop. + * @param loops All the loops to be fused, stored in ascending order. + * @return The fused loop. + */ + Expr Fuse(const std::vector& loops); + + /** + * \brief Fuse for loops and return the fused loop. + * @param block_name Name of the block we want to modify. + * @param loops_index Indices of the loops to be fused, stored in ascending order. + * @return The fused loop. + */ + Expr Fuse(const std::string& block_name, const std::vector& loops_index); + + /** + * \brief Fuse for loops and return the fused loop. + * @param block The block we want to modify. + * @param loops_index Indices of the loops to be fused, stored in ascending order. + * @return The fused loop. + */ + Expr Fuse(const Expr& block, const std::vector& loops_index); + + /** + * \brief Move a producer block's location under a specific loop. + * @param block The block we want to move its computation location. + * @param loop The loop we will move the block to. + * @param keep_unit_loops Whether to keep the unit loop. + */ + void ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops = false); + + /** + * \brief Move a block's location under a loop without considering their dependency. + * @param block The block we want to move its computation location. + * @param loop The loop we will move the block to. + */ + void SimpleComputeAt(const Expr& block, const Expr& loop); + + /** + * \brief Move a consumer block's location under a specific loop. + * @param block The block we want to move its computation location. + * @param loop The loop we will move the block to. + * @param keep_unit_loops Whether to keep the unit loop. + */ + void ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops = false); + + /** + * \brief Find an expr's root ScheduleBlockRealize node + * @param expr The expr node. + * @return Its root ScheduleBlockRealize node. + */ + Expr GetRootBlock(const Expr& expr) const; + + /** + * \brief Find a buffer that is being read, and create its cache. + * @param block Block that reads the buffer. + * @param read_buffer_index Index of the buffer being read in block. + * @param memory_type String that indicates the buffer's storage scope. + * @return The buffer's cache. 
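+   *
+   * A minimal sketch of the effect (tensor names here are illustrative; the real
+   * cache name comes from MakeCacheTensor):
+   * \code
+   * // before: block reads A directly
+   * B[i] = A[i] + 1
+   * // after CacheRead(block, 0, "shared"): the returned cache block fills the
+   * // staging buffer, and the original block reads from it instead
+   * A_shared[i] = A[i]
+   * B[i] = A_shared[i] + 1
+   * \endcode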
+  /**
+   * \brief Find a buffer that is being read, and create its cache.
+   * @param block Block that reads the buffer.
+   * @param read_buffer_index Index of the buffer being read in block.
+   * @param memory_type String that indicates the buffer's storage scope.
+   * @return The buffer's cache.
+   */
+  Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type);
+
+  /**
+   * \brief Find a buffer that is being written, and create its cache.
+   * @param block Block that writes the buffer.
+   * @param write_buffer_index Index of the buffer being written in block.
+   * @param memory_type String that indicates the buffer's storage scope.
+   * @return The buffer's cache.
+   */
+  Expr CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type);
+
+  /**
+   * \brief Add SyncThreads statements in the AST.
+   * @param ir_node The insertion point in the AST.
+   * @param after_node Whether to insert the statement after the insertion point. When it is true, we insert the
+   * SyncThreads statement after the insertion IR; when it is false, we insert the SyncThreads statement before the
+   * insertion IR.
+   */
+  void SyncThreads(const Expr& ir_node, bool after_node = true);
+
+  /*!
+   * \brief Set a tensor's buffer type (memory_type).
+   * \param block The ScheduleBlockRealize corresponding to a unique tensor.
+   * \param memory_type The memory type we want to set. Should be "local", "shared" or "global".
+   */
+  void SetBuffer(Expr& block, const std::string& memory_type, bool fixed = false);
+
+  /**
+   * \brief Reorder the loops in the order of the vector.
+   * @param loops The loops to be reordered.
+   * @return The reordered Expr, which can be ir::For or ir::Block. It is ir::For if
+   * the reordered loop is a single loop chain. It will be an ir::Block whose
+   * stmts contain several loop chains if the reordered computation has
+   * multiple loop chains.
+   */
+  Expr Reorder(const std::vector<Expr>& loops);
+
+  /**
+   * \brief Reorder the loops in the order of the vector elements.
+   * @param block_name Name of the block we want to modify.
+   * @param loops_index Indices of the loops to be reordered.
+   * @return The reordered Expr, which can be ir::For or ir::Block. It is ir::For if
+   * the reordered loop is a single loop chain. It will be an ir::Block whose
+   * stmts contain several loop chains if the reordered computation has
+   * multiple loop chains.
+   */
+  Expr Reorder(const std::string& block_name, const std::vector<int>& loops_index);
+
+  /**
+   * \brief Reorder the loops in the order of the vector elements.
+   * @param block The block we want to modify.
+   * @param loops_index Indices of the loops to be reordered.
+   * @return The reordered Expr, which can be ir::For or ir::Block. It is ir::For if
+   * the reordered loop is a single loop chain. It will be an ir::Block whose
+   * stmts contain several loop chains if the reordered computation has
+   * multiple loop chains.
+   */
+  Expr Reorder(const Expr& block, const std::vector<int>& loops_index);
+
+  /**
+   * \brief Get the device api of this IRSchedule.
+   * @return The device api of this IRSchedule.
+   */
+  DeviceAPI GetDeviceAPI() const;
+
+  /**
+   * \brief Change a for loop to be parallelized/vectorized/unrolled.
+   * @param loop The for loop to parallelize/vectorize/unroll.
+   * @param for_type The target for-loop type.
+   */
+  void MutateForType(const Expr& loop, ForType for_type, int factor = -1);
+
+  /**
+   * \brief Parallelize the given loop.
+   * @param loop the loop to parallelize.
+   */
+  void Parallel(const Expr& loop);
+
+  /**
+   * \brief Vectorize the given loop.
+   * @param loop the loop to vectorize.
+   * @param factor the vectorization factor.
+   */
+  void Vectorize(const Expr& loop, int factor);
+
+  /**
+   * \brief Unroll the given loop.
+   * @param loop the loop to unroll.
+   */
+  void Unroll(const Expr& loop);
+
+  /**
+   * \brief Mark a schedule block as inlined.
+   * @param schedule_block the schedule block to be inlined.
+   */
+  void ComputeInline(const Expr& schedule_block);
+
+  /**
+   * \brief Inline a consumer block into its only producer.
+   * @param schedule_block the schedule block to be inlined.
+   */
+  void ReverseComputeInline(const Expr& schedule_block);
+
+  /**
+   * \brief Bind the loop to the given thread axis.
+   * @param loop the loop to bind.
+   * @param thread_axis the name of the thread axis the loop will be bound to.
+   */
+  void Bind(const Expr& loop, const std::string& thread_axis);
+
+  //! Copy another block's schedule transform.
+  void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target);
+
+  void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name);
+
+  /**
+   * \brief Factorize the reduction block by the given loop. The block will be split into two blocks: the rfactor
+   * block and the final write-back block.
+   * @param rf_loop the reduce loop to do the rfactor transformation on.
+   * @param rf_axis the axis where the newly generated loop is placed in the rfactor block.
+   * @return The newly created rfactor tensor.
+   *
+   * For example, input the block:
+   * \code
+   * for (i, 0, 10)      // serial loop
+   *   B_init[i] = 0
+   *   for (j, 0, 20)    // reduce loop
+   *     for (k, 0, 30)  // reduce loop
+   *       B[i] = B[i] + A[i, j, k]
+   * \endcode
+   *
+   * If the rfactor loop is k and rf_axis is 0, the rfactor transformation is divided into 2 steps:
+   * 1. Get the rfactor block, where the reduce loop k is transformed into a serial loop with no accumulation and a
+   * new rfactor tensor is created. The axis k will be placed in the rf_axis of the new rf_tensor. The rf_block is as
+   * follows:
+   * \code
+   * for (rf_k, 0, 30)   // rfactor loop k is transformed into a serial loop.
+   *   for (i, 0, 10)    // serial loop
+   *     rf_B_init[rf_k, i] = 0
+   *     for (j, 0, 20)  // reduce loop
+   *       rf_B[rf_k, i] = rf_B[rf_k, i] + A[i, j, rf_k]
+   * \endcode
+   * 2. Do the reduction of the rfactor loop k to get the final result block:
+   * \code
+   * for (i, 0, 10)  // serial loop
+   *   B_init[i] = 0
+   *   for (k, 0, 30)
+   *     B[i] = B[i] + rf_B[k, i]
+   * \endcode
+   */
+  Expr Rfactor(const Expr& rf_loop, int rf_axis);
+
+  /*!
+   * \brief Annotate a block with a key-value pair to set as its attribute.
+   * \param block The block to be annotated.
+   * \param key The attribute key.
+   * \param value The attribute value; its type should be one of those listed in attr_t.
+   */
+  void Annotate(const Expr& block, const std::string& key, const attr_t& value);
+
+  /*!
+   * \brief Cancel an annotation within a block using the key.
+   * \param block The block to be unannotated.
+   * \param key The attribute key.
+   */
+  void Unannotate(Expr& block, const std::string& key);
+
+  /*!
+   * \brief Flatten the loops into one dim.
+   * \param loops the loops to be flattened.
+   * \param force_flat whether to force the flattening of the right value.
+   */
+  // Temporary solution for simplifying the elementwise/broadcast/injective index.
+  // TODO(sunli): Solve Index Simplify.
+  void FlattenLoops(const std::vector<Expr>& loops, const bool force_flat = false);
+
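+  // Sketch of a typical GPU binding sequence built from the primitives above
+  // (hypothetical block "C" and extents; "blockIdx.x"/"threadIdx.x" are the
+  // thread-axis names accepted by Bind):
+  // \code
+  //   auto loops = sch.GetLoops("C");
+  //   auto parts = sch.Split(loops[0], {-1, 256});
+  //   sch.Bind(parts[0], "blockIdx.x");
+  //   sch.Bind(parts[1], "threadIdx.x");
+  // \endcode
+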
+  /*!
+   * \brief Sample the factors to tile a specific loop perfectly.
+   * \param loop the loop to be split.
+   * \param n the number of loop layers to split into.
+   * \param max_innermost_factor the maximum factor of the innermost loop.
+   * \param decision the decision data of the last sample, or artificially given decision data.
+   * \return the split factors of the loop (the larger the index, the more inner the corresponding loop).
+   * For example, a return value of {16, 64} means the loop will look like this:
+   * \code
+   * for (i, 0, 16) {
+   *   for (j, 0, 64) {
+   *     ...
+   *   }
+   * }
+   * \endcode
+   */
+  std::vector<Expr> SamplePerfectTile(const Expr& loop,
+                                      int n,
+                                      int max_innermost_factor,
+                                      const std::vector<int>& decision = {});
+
+  /*!
+   * \brief Insert a tag in schedule_desc to mark the beginning of post processing.
+   * The schedule primitive itself does not make any changes to the IR.
+   */
+  void TagPostSchedule();
+
+  /**
+   * \brief Randomly sample an integer according to the given distribution.
+   * @param candidates Candidate set of integers.
+   * @param probs Probability distribution over the candidate integer set.
+   * @param decision the decision data of the last sample, or artificially given decision data.
+   * @return The sampled random variable.
+   */
+  Expr SampleCategorical(const std::vector<int>& candidates,
+                         const std::vector<float>& probs,
+                         const std::vector<int>& decision = {});
+
+ private:
+  // Init the random seed with a new seed
+  void InitSeed(utils::LinearRandomEngine::StateType rand_seed);
+
+  // Fork a new seed from the current seed
+  utils::LinearRandomEngine::StateType ForkSeed() const;
+
+ private:
+  std::unique_ptr<ScheduleImpl> impl_;
+  mutable ScheduleDesc trace_;  // trace the scheduling process
+  mutable utils::LinearRandomEngine::StateType rand_seed_;
+};
+
+/*!
+ * \brief The base class of the inliners, which handles:
+ * 1) Removing the block to be inlined
+ * 2) Maintaining a list of index variables and their substitution for the buffer being inlined
+ */
+class BaseInliner : public ir::IRMutator<> {
+ protected:
+  explicit BaseInliner(const Tensor& inlined_tensor, const Expr& inlined_store)
+      : inlined_tensor_(inlined_tensor), inlined_store_(inlined_store) {}
+
+ public:
+  void operator()(Expr* expr);
+
+ private:
+  void Visit(const ir::Block* expr, Expr* op) override;
+
+ protected:
+  //! Check if the indices are valid. If so, set idx_vars_ properly.
+  bool UpdateAndCheckIndexVars(const std::vector<Expr>& indices, int expected_ndim);
+
+  void SetIndexSubstitution(const std::vector<Expr>& indices);
+
+ protected:
+  //! The tensor to be inlined
+  Tensor inlined_tensor_{nullptr};
+  //! The body of the block to be inlined
+  Expr inlined_store_{nullptr};
+  //! The indices used for indexing the buffer to be inlined
+  std::vector<Var> idx_vars_;
+  //! Replacing vars (idx_sub_var_) in indices with the corresponding exprs (idx_sub_expr_)
+  std::vector<Var> idx_sub_var_;
+  std::vector<Expr> idx_sub_expr_;
+
+ public:
+  /*!
+   * \brief The Expr to be replaced when removing the block
+   * \note The pair (src_stmt, tgt_stmt) is produced by LeafBlockRemovalPlan
+   */
+  Expr src_stmt{nullptr};
+  //! The Expr to replace the original one when removing the block
+  Expr tgt_stmt{nullptr};
+};
+
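+// One plausible driving sequence for the helpers declared around here (a
+// hedged sketch; `block`, `store` and `root` are assumed to be supplied by the
+// caller, e.g. the compute-inline implementation in ScheduleImpl):
+// \code
+//   ComputeInliner inliner(store.As<ir::Store>()->tensor.as_tensor_ref(), store);
+//   CHECK(inliner.BodyPatternAllowInline());
+//   // Fill (src_stmt, tgt_stmt) with the removal plan for `block`, then apply.
+//   LeafBlockRemovalPlan remove_plan(block, &inliner.src_stmt, &inliner.tgt_stmt);
+//   remove_plan(&root);
+//   inliner(&root);
+// \endcode
+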
+/*!
+ * \brief Helper to inline a producer block into its consumer(s).
+ * The derived class implements:
+ * Substituting a `Load` of the tensor to be inlined with its value calculation in the producer block
+ */
+class ComputeInliner : public BaseInliner {
+ public:
+  explicit ComputeInliner(const Tensor& inlined_tensor, const Expr& inlined_store)
+      : BaseInliner(inlined_tensor, inlined_store) {}
+
+  bool BodyPatternAllowInline();
+
+ private:
+  void Visit(const ir::Load* expr, Expr* op) override;
+
+  //! Replace a 'Load' node of the tensor with the value computed in its producer.
+  Expr ReplaceInlinedTensor(Expr* load);
+};
+
+/*!
+ * \brief Helper to inline a block into its producer.
+ * The derived class implements the following functionalities:
+ * 1) Substituting a `Load` of the tensor to be inlined
+ *    with its value calculation in the producer block
+ * 2) Analyzing the producer block to determine the remapping of index variables
+ */
+class ReverseComputeInliner : public BaseInliner {
+ public:
+  explicit ReverseComputeInliner(const Tensor& inlined_tensor,
+                                 const Expr& inlined_store,
+                                 const Expr& inlined_load,
+                                 const Expr& target_store)
+      : BaseInliner(inlined_tensor, inlined_store), inlined_load_(inlined_load), target_store_(target_store) {}
+
+  bool BodyPatternAllowInline();
+
+ protected:
+  Expr inlined_load_{nullptr};
+  Expr target_store_{nullptr};
+
+ private:
+  void Visit(const ir::Load* expr, Expr* op) override;
+  void Visit(const ir::Store* expr, Expr* op) override;
+
+  //! Replace a 'Load' node of the tensor with the 'Store' node of its consumer.
+  Expr ReplaceInlinedTensor(Expr* load);
+  Expr ReplaceTargetTensor(Expr* store);
+};
+
+// The struct used to remove the original block in ComputeAt.
+class LeafBlockRemovalPlan : public ir::IRMutator<> {
+ public:
+  LeafBlockRemovalPlan(const Expr& block, Expr* source_expr, Expr* target_expr)
+      : block_(block), source_expr_(source_expr), target_expr_(target_expr) {}
+
+  void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::ScheduleBlockRealize* expr, Expr* op) override {
+    if (*op == block_) {
+      find_block = true;
+      return;
+    }
+    IRMutator::Visit(expr, op);
+  }
+
+  void Visit(const ir::For* expr, Expr* op) override {
+    if (*op == block_) {
+      find_block = true;
+      return;
+    }
+    IRMutator::Visit(expr, op);
+  }
+
+  void Visit(const ir::Block* expr, Expr* op) override {
+    if (expr->stmts.size() > 1U) {
+      int block_index = -1;
+      for (int i = 0; i < expr->stmts.size(); ++i) {
+        auto keep_flag = find_block;
+        find_block = false;
+        auto* node = op->As<ir::Block>();
+        IRMutator::Visit(&node->stmts[i], &node->stmts[i]);
+        if (find_block) {
+          if (depth == 0) {
+            *source_expr_ = *op;
+            block_index = i;
+          }
+          depth++;
+        }
+        find_block = find_block || keep_flag;
+      }
+      if (block_index != -1) {
+        std::vector<Expr> new_stmts;
+        for (int i = 0; i < expr->stmts.size(); ++i) {
+          if (i == block_index)
+            continue;
+          else
+            new_stmts.push_back(expr->stmts[i]);
+        }
+        auto target_block = ir::Block::Make(new_stmts);
+        *target_expr_ = target_block;
+      }
+    } else {
+      IRMutator::Visit(expr, op);
+    }
+  }
+
+ private:
+  bool find_block{false};
+  int depth{0};
+  const Expr& block_;
+  Expr* source_expr_;
+  Expr* target_expr_;
+};
+
+class ComputeInlineChecker : public ir::IRMutator<> {
+ public:
+  ComputeInlineChecker(IRSchedule& schedule, Expr& block) : ir_schedule_(schedule), block_(block) {}
+
+  bool Check();
+
+  void BuildDataDependency();
+
+ private:
+  void Visit(const ir::Load* expr, Expr* op) {
+    // Check whether this Load reads the tensor written by the captured Store
+    if ((store_.As<ir::Store>()->tensor).as_tensor_ref()->name == expr->tensor.as_tensor_ref()->name) {
+      should_skip_ = false;
+      return;
+    }
+    IRMutator::Visit(expr, op);
+  }
+
+ private:
+  IRSchedule& ir_schedule_;
+  Expr& block_;
+
+  Expr store_;
+  bool should_skip_{true};
+};
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/ir_schedule_util.cc b/paddle/cinn/ir/ir_schedule_util.cc
new file mode 100644
index 0000000000000..054e05dee06d3
--- /dev/null
+++ b/paddle/cinn/ir/ir_schedule_util.cc
@@ -0,0 +1,1038 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/ir/ir_schedule_util.h"
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "cinn/common/cas.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/collect_ir_nodes.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/lang/compute.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/ir_simplify.h"
+#include "cinn/optim/replace_var_with_expr.h"
+
+namespace cinn {
+namespace ir {
+
+Tensor GetTensor(const Expr& block) {
+  CHECK(block.As<ir::ScheduleBlockRealize>());
+  auto find_tensor = ir::CollectIRNodesWithoutTensor(
+      block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
+  CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node (except for the root block)!";
+  CHECK((*find_tensor.begin()).As<ir::Store>()->tensor.as_tensor());
+  Tensor tensor = (*find_tensor.begin()).As<ir::Store>()->tensor.as_tensor_ref();
+  return tensor;
+}
+
+Tensor GetReadTensor(const Expr& block, int index) {
+  CHECK(block.As<ir::ScheduleBlockRealize>());
+  auto find_tensor = ir::CollectIRNodesWithoutTensor(
+      block, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
+  CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node (except for the root block)!";
+  std::vector<Tensor> res;
+  auto find_read_tensor = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) {
+    if (x->As<ir::Load>()) res.push_back(x->As<ir::Load>()->tensor.as_tensor_ref());
+    return x->As<ir::Load>();
+  });
+  CHECK_EQ(find_read_tensor.size(), res.size());
+  CHECK(!find_read_tensor.empty()) << "Didn't find Load tensor in block!";
+  CHECK_LT(index, (int)find_read_tensor.size()) << "Index should be less than the number of read tensors!";
+  return res[index];
+}
+
+int GetLoopExtent(const Expr& loop) {
+  CHECK(loop.As<ir::For>());
+  CHECK(common::is_zero(loop.As<ir::For>()->min));
+  CHECK(loop.As<ir::For>()->extent.is_constant());
+  return (int)loop.As<ir::For>()->extent.get_constant();
+}
+
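+// A quick note on the contract above (a sketch, not extra checking logic):
+// GetLoopExtent only accepts a canonical loop such as
+//   for (i, 0, 32) { ... }
+// and returns 32; a non-zero `min` or a symbolic `extent` fails the CHECKs.
+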
+void SetCudaAxisInfo(Expr* lowered_func) {
+  if (!lowered_func->as_lowered_func()) {
+    LOG(ERROR) << "The input of SetCudaAxisInfo should be a lowered_func!";
+    return;
+  }
+
+  auto func_body = lowered_func->as_lowered_func_ref()->body;
+  CudaAxisInfo info;
+
+  auto block_nodes = ir::CollectIRNodes(func_body, [&](const Expr* x) {
+    if (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid()) {
+      auto bind_info = x->As<ir::For>()->bind_info();
+      info.set_valid(true);
+      if (bind_info.for_type == ForType::GPUThread) {
+        CHECK(common::is_zero(x->As<ir::For>()->min));
+        CHECK(x->As<ir::For>()->extent.is_constant());
+        int range = x->As<ir::For>()->extent.get_constant();
+        range = range > info.block_dim(bind_info.offset) ? range : info.block_dim(bind_info.offset);
+        VLOG(3) << "Set block dim[" << bind_info.offset << "] with range " << range;
+        info.set_block_dim(bind_info.offset, range);
+      } else if (bind_info.for_type == ForType::GPUBlock) {
+        CHECK(common::is_zero(x->As<ir::For>()->min));
+        CHECK(x->As<ir::For>()->extent.is_constant());
+        int range = x->As<ir::For>()->extent.get_constant();
+        range = range > info.grid_dim(bind_info.offset) ? range : info.grid_dim(bind_info.offset);
+        info.set_grid_dim(bind_info.offset, range);
+        VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range " << range;
+      } else {
+        LOG(FATAL) << "The for loop's bind info should be gpu block or thread!";
+      }
+    }
+    return (x->As<ir::For>() && x->As<ir::For>()->bind_info().valid());
+  });
+  lowered_func->as_lowered_func_ref()->cuda_axis_info = info;
+}
+
+bool Contains(const Expr& container, const Expr& expr) {
+  auto find_expr = ir::CollectIRNodesWithoutTensor(
+      container, [&](const Expr* x) { return (x->node_type() == expr.node_type() && *x == expr); }, true);
+  return (!find_expr.empty());
+}
+
+Expr GetNextForLoop(const Expr& for_loop) {
+  Expr result;
+  CHECK(for_loop.As<ir::For>()) << "The input of GetNextForLoop should be ir::For!";
+  Expr for_body = for_loop.As<ir::For>()->body;
+  ir::Block* for_body_block = for_body.As<ir::Block>();
+  CHECK(for_body_block) << "The for_loop's body should be Block!";
+
+  // Only support the case where the body block contains a single sub for loop
+  int next_idx = -1;
+  for (int i = 0; i < for_body_block->stmts.size(); ++i) {
+    Expr stmt = for_body_block->stmts[i];
+    if (stmt.As<ir::IfThenElse>() || stmt.As<ir::For>()) {
+      if (next_idx == -1) {
+        next_idx = i;
+      } else {
+        // More than one sub for loop, return undefined.
+        return result;
+      }
+    }
+  }
+  if (next_idx == -1) {
+    // No sub for loop found, return undefined.
+    return result;
+  }
+
+  Expr block_body = for_body_block->stmts[next_idx];
+  if (block_body.As<ir::IfThenElse>()) {
+    // TODO(zhhsplendid): is it right to only handle the true case?
+    // It may be wrong, but the code was written by a previous developer;
+    // we will check it later.
+    CHECK(block_body.As<ir::IfThenElse>()->true_case.As<ir::Block>());
+    Expr true_case = block_body.As<ir::IfThenElse>()->true_case;
+    if (true_case.As<ir::Block>()->stmts.size() != 1U || !true_case.As<ir::Block>()->stmts[0].As<ir::For>())
+      return result;
+    result = true_case.As<ir::Block>()->stmts[0];
+    return result;
+  } else if (block_body.As<ir::For>()) {
+    return block_body;
+  } else {
+    return result;
+  }
+}
+
+std::vector<Expr> GetIfThenElseInRange(const Expr& top, const Expr& bottom) {
+  std::vector<Expr> if_nodes;
+  CHECK(top.As<ir::For>());
+  CHECK(bottom.As<ir::For>());
+  for (auto loop_iter = top; loop_iter != bottom;) {
+    CHECK(loop_iter.As<ir::For>());
+    CHECK(loop_iter.As<ir::For>()->body.As<ir::Block>()) << "For node's body should be Block!";
+    auto block = loop_iter.As<ir::For>()->body.As<ir::Block>();
+    for (Expr tmp : block->stmts) {
+      if (tmp.As<ir::IfThenElse>()) {
+        if_nodes.push_back(tmp);
+        CHECK(tmp.As<ir::IfThenElse>()->true_case.As<ir::Block>());
+        Expr true_case = tmp.As<ir::IfThenElse>()->true_case;
+        CHECK(true_case.As<ir::Block>()->stmts.size() == 1U && true_case.As<ir::Block>()->stmts[0].As<ir::For>());
+        tmp = true_case.As<ir::Block>()->stmts[0];
+      }
+      if (tmp.As<ir::For>()) {
+        loop_iter = tmp;
+      }
+    }
+  }
+  return if_nodes;
+}
+
Please check."; + if (replaced.empty()) return; + std::map replacing_map; + for (int i = 0; i < replaced.size(); ++i) { + // If the Var to be replaced is equal to the candidate, we skip it. + if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) continue; + replacing_map[replaced[i]] = candidates[i]; + } + MappingVarToExprMutator mapper(replacing_map); + mapper(source); + return; +} + +std::vector ValidateFactors(const std::vector& factors, int total_extent) { + CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; + bool has_minus_one = false; + int product = 1; + for (auto& i : factors) { + CHECK(i != 0) << "The params in factors of Split should not be 0! Please check."; + CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check."; + if (i == -1) { + CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check."; + has_minus_one = true; + } else { + product *= i; + } + } + std::vector validated_factors = factors; + if (!has_minus_one) { + CHECK_GE(product, total_extent) + << "In Split, the factors' product should be equal to original loop's extent! Please check."; + return validated_factors; + } else { + CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= " + "original loop's extent! Please check."; + int minus_one_candidate = (int)ceil((double)total_extent / (double)product); + for (int i = 0; i < validated_factors.size(); ++i) { + if (validated_factors[i] == -1) { + validated_factors[i] = minus_one_candidate; + } + } + return validated_factors; + } +} + +void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) { + auto* rf_for = rf_loop.As(); + CHECK(rf_for) << "Expr param of Rfactor must be For node! 
Please check."; + // check the rf_loop only has one schedule block + auto block_nodes = ir::CollectIRNodesWithoutTensor( + rf_loop, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(block_nodes.size(), 1U) << "Rfactor Loop should only have one schedule block"; + auto find_store = ir::CollectIRNodesWithoutTensor( + rf_loop, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_store.size(), 1U); + auto indice = find_store.begin()->As()->indices; + // check rf_axis + CHECK_LE(rf_axis, indice.size()) << "rf_axis should not be greater than store's domain size"; + // check rfactor loop is reduce + auto* sch_block_realize = block_nodes.begin()->As(); + auto* sch_block = sch_block_realize->schedule_block.As(); + CHECK(sch_block); + auto& iter_values = sch_block_realize->iter_values; + auto& iter_vars = sch_block->iter_vars; + CHECK_EQ(iter_values.size(), iter_vars.size()); + auto rf_loop_var = rf_for->loop_var; + Var rf_block_var; + for (int i = 0; i < iter_values.size(); ++i) { + if (ContainVar({iter_values[i]}, rf_loop_var->name)) { + CHECK(!rf_block_var.defined()) << "rfactor loop var can only be binded to one block var"; + auto iter_value = iter_values[i].As<_Var_>(); + CHECK(iter_value) << "not support complex reduce bindings"; + rf_block_var = iter_vars[i]; + auto it = std::find_if(indice.begin(), indice.end(), [&](const Expr& x) { + return x.As<_Var_>() && x.As<_Var_>()->name == rf_block_var->name; + }); + CHECK(it == indice.end()) << "rfactor loop var is not reduce, please check!"; + } + } +} + +std::vector GetLoopsOfExpr(const Expr& expr, const Expr& root) { + auto loop_nodes = + ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { return x->As() && Contains(*x, expr); }); + std::vector result(loop_nodes.begin(), loop_nodes.end()); + if (result.empty()) LOG(FATAL) << "Didn't find expr's : \n" << expr << "\n loops in root : \n" << root; + std::sort(result.begin(), result.end(), [&](Expr i, Expr j) { + return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); + }); + return result; +} + +IterRange GetAccessedRange(const Expr& index, + const std::vector& iter_vars, + const std::vector& iter_ranges) { + CHECK_EQ(iter_vars.size(), iter_ranges.size()); + std::vector var_mins, var_maxs; + for (const auto& range : iter_ranges) { + var_mins.emplace_back(range.min); + var_maxs.emplace_back(range.min + range.extent - 1); + } + + Expr indice_min = optim::IRCopy(index); + Expr indice_max = optim::IRCopy(index); + // replace the var by the corresponding iter_value + ReplaceExpr(&indice_min, iter_vars, var_mins); + ReplaceExpr(&indice_max, iter_vars, var_maxs); + // simplify expression + indice_min = common::AutoSimplify(indice_min); + indice_max = common::AutoSimplify(indice_max); + + Expr indice_extent; + Expr mod_extent(0); + if (indice_min.As() && indice_min.As()->b().is_constant()) mod_extent = indice_min.As()->b(); + + if (indice_min == indice_max) { + if (common::is_zero(mod_extent)) { + // If a index keeps constant, its extent should be 1. 
+IterRange GetAccessedRange(const Expr& index,
+                           const std::vector<Var>& iter_vars,
+                           const std::vector<IterRange>& iter_ranges) {
+  CHECK_EQ(iter_vars.size(), iter_ranges.size());
+  std::vector<Expr> var_mins, var_maxs;
+  for (const auto& range : iter_ranges) {
+    var_mins.emplace_back(range.min);
+    var_maxs.emplace_back(range.min + range.extent - 1);
+  }
+
+  Expr indice_min = optim::IRCopy(index);
+  Expr indice_max = optim::IRCopy(index);
+  // replace the vars by their corresponding range bounds
+  ReplaceExpr(&indice_min, iter_vars, var_mins);
+  ReplaceExpr(&indice_max, iter_vars, var_maxs);
+  // simplify the expressions
+  indice_min = common::AutoSimplify(indice_min);
+  indice_max = common::AutoSimplify(indice_max);
+
+  Expr indice_extent;
+  Expr mod_extent(0);
+  if (indice_min.As<Mod>() && indice_min.As<Mod>()->b().is_constant()) mod_extent = indice_min.As<Mod>()->b();
+
+  if (indice_min == indice_max) {
+    if (common::is_zero(mod_extent)) {
+      // If an index keeps constant, its extent should be 1.
+      indice_extent = Expr(1);
+    } else {
+      indice_extent = mod_extent;
+    }
+  } else {
+    indice_extent = common::AutoSimplify(common::AutoSimplify(indice_max) - common::AutoSimplify(indice_min) + 1);
+  }
+
+  if (indice_extent.is_constant() && indice_extent.get_constant() < 0) {
+    VLOG(3) << "deduced indices are not constant";
+    indice_min = indice_max;
+    indice_extent = Expr(-indice_extent.get_constant());
+  }
+  VLOG(3) << "indice_min=" << indice_min << ", indice_max=" << indice_max << ", indice_extent=" << indice_extent;
+  return IterRange(indice_min, indice_extent);
+}
+
+std::vector<IterRange> CalculateTensorRegions(const Expr& block,
+                                              const std::vector<Expr>& tensor_indices,
+                                              const Tensor& tensor,
+                                              const Expr& root) {
+  CHECK(block.As<ir::ScheduleBlockRealize>());
+  auto iter_vars = block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
+  auto iter_values = block.As<ir::ScheduleBlockRealize>()->iter_values;
+
+  std::vector<Var> loop_vars;
+  std::vector<IterRange> loop_ranges;
+
+  auto outer_loops = GetLoopsOfExpr(block, root);
+  for (auto& loop : outer_loops) {
+    CHECK(loop.As<ir::For>());
+    loop_vars.emplace_back(loop.As<ir::For>()->loop_var);
+    loop_ranges.emplace_back(IterRange(loop.As<ir::For>()->min, loop.As<ir::For>()->extent));
+  }
+
+  std::vector<IterRange> result;
+  for (int i = 0; i < tensor_indices.size(); ++i) {
+    Expr binded_index = optim::IRCopy(tensor_indices[i]);
+    ReplaceExpr(&binded_index, iter_vars, iter_values);
+    auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges);
+
+    // In general the range should be constant, but in some cases AutoSimplify
+    // (our algebraic simplification function) cannot simplify it completely, so
+    // we conservatively use the whole shape of this index as the accessed range.
+    if (!range.min.is_constant() || !range.extent.is_constant()) {
+      VLOG(3) << "deduced range is not constant, range.min=" << range.min << ", range.extent=" << range.extent;
+      if (tensor->buffer.defined()) {
+        CHECK_GT((int)tensor->buffer->shape.size(), i);
+        result.emplace_back(IterRange(Expr(0), tensor->buffer->shape[i]));
+      } else {
+        CHECK_GT((int)tensor->shape.size(), i);
+        result.emplace_back(IterRange(Expr(0), tensor->shape[i]));
+      }
+    } else {
+      result.emplace_back(std::move(range));
+    }
+  }
+
+  return result;
+}
+
+Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) {
+  CHECK(block.As<ir::ScheduleBlockRealize>());
+  auto compute_body = block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body;
+  if (is_write) {
+    std::vector<Expr> find_store_vec;
+    auto find_store = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
+      if (x->As<ir::Store>()) find_store_vec.push_back(*x);
+      return x->As<ir::Store>();
+    });
+    CHECK_EQ(find_store.size(), find_store_vec.size());
+    CHECK_LT(index, (int)find_store.size());
+    Expr store_index = find_store_vec[index];
+    return store_index;
+  } else {
+    std::vector<Expr> find_load_vec;
+    auto find_load = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) {
+      if (x->As<ir::Load>()) find_load_vec.push_back(*x);
+      return x->As<ir::Load>();
+    });
+    CHECK_EQ(find_load.size(), find_load_vec.size());
+    CHECK_LT(index, (int)find_load.size());
+    Expr load_index = find_load_vec[index];
+    return load_index;
+  }
+}
+
+Tensor MakeCacheTensor(const Tensor& tensor, const std::string& memory_type) {
+  auto cache_tensor = lang::Compute(
+      tensor->shape,
+      [=](const std::vector<Expr>& dims) { return tensor(dims); },
+      tensor->name + "_" + memory_type + "_temp_buffer");
+  cache_tensor->WithBuffer(memory_type);
+  return cache_tensor;
+}
+
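+// For instance (a sketch): caching a tensor named "A" with memory_type
+// "shared" yields a tensor "A_shared_temp_buffer" bound to a shared-memory
+// buffer, which is the name the cache block created below will write to.
+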
+Expr MakeCacheBlock(const std::vector<IterRange>& buffer_ranges,
+                    CacheBlockInfo* info,
+                    const std::string& memory_type,
+                    DeviceAPI device_api) {
+  // loop variables
+  std::vector<Var> loop_vars;
+  // bindings in the block realize
+  std::vector<Expr> iter_values;
+  // Create loop vars and block vars' binding_value
+  for (const auto& range : buffer_ranges) {
+    Var loop_var(common::UniqName("cache_ax" + std::to_string(loop_vars.size())));
+    // Var loop_var("ax" + std::to_string(loop_vars.size()));
+    loop_vars.push_back(loop_var);
+    iter_values.push_back(common::AutoSimplify(range.min + loop_var));
+  }
+  // block variables
+  std::vector<Var> block_vars;
+  Tensor new_tensor = info->alloc;
+  // Create block vars, the block's accessed region and accessing indices
+  CHECK(new_tensor->buffer.defined());
+  for (auto& dim : new_tensor->buffer->shape) {
+    Var var(Expr(0), dim, "v" + std::to_string(block_vars.size()), false);
+    block_vars.push_back(var);
+  }
+  auto body = new_tensor->tensor_store_expanded_body();
+  std::vector<Var> axis_vars = common::GenDefaultAxis(new_tensor->domain.size());
+  axis_vars.insert(axis_vars.end(), new_tensor->reduce_axis.begin(), new_tensor->reduce_axis.end());
+  for (int i = 0; i < axis_vars.size(); ++i) {
+    optim::ReplaceVarWithExpr(&body, axis_vars[i], block_vars[i]);
+  }
+  Expr block = ir::ScheduleBlockRealize::Make(
+      iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, new_tensor->name, Block::Make({body})));
+  Expr new_body = block;
+  for (int i = (int)loop_vars.size() - 1; i >= 0; i--) {
+    new_body = For::Make(loop_vars[i],
+                         Expr(0),
+                         common::AutoSimplify(buffer_ranges[i].extent),
+                         ir::ForType::Serial,
+                         device_api,
+                         ir::Block::Make({new_body}));
+  }
+  info->cache_block = std::move(new_body);
+  return block;
+}
+
+void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write) {
+  Expr find_tensor = is_write ? Expr(info->write_tensor) : Expr(info->read_tensor);
+  auto find_produce_read = ir::CollectIRNodesWithoutTensor(
+      root, [&](const Expr* x) { return x->As<ir::Store>() && x->As<ir::Store>()->tensor == find_tensor; });
+
+  if (find_produce_read.empty()) {
+    CHECK(root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>());
+    CHECK(root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body.As<ir::Block>());
+    info->loc_block = root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body;
+    info->loc_pos = 0;
+    return;
+  }
+
+  CHECK_EQ(find_produce_read.size(), 1U);
+  Expr producer = *(find_produce_read.begin());
+
+  CHECK(root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>());
+  CHECK(root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body.As<ir::Block>());
+  info->loc_block = root.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body;
+  for (int i = 0; i < (int)info->loc_block.As<ir::Block>()->stmts.size(); ++i) {
+    if (Contains(info->loc_block.As<ir::Block>()->stmts[i], producer)) {
+      info->loc_pos = i + 1;
+      break;
+    }
+  }
+}
+
+const std::set<Expr, CompExpr> CollectLoopsToSet(const std::vector<Expr>& loops) {
+  std::set<Expr, CompExpr> for_loops;
+  for (auto& i : loops) {
+    CHECK(i.As<ir::For>()) << "loops should be For nodes! Please check.";
+    auto inserted = for_loops.insert(i);
+    if (!inserted.second) {
+      LOG(FATAL) << "There should be no duplicate elements in loops! Please check.";
+    }
+  }
+  return for_loops;
+}
+
+// This function is used in the Reorder schedule primitive. Since the input loop
+// Expr(s) of Reorder don't give the original for-loop order, we have to find
+// the top (outermost) loop and the bottom (innermost) one among the loop Expr(s)
+std::pair<Expr, Expr> GetBoundaryOfReorderRange(const std::set<Expr, CompExpr>& loop_set) {
+  Expr top = *loop_set.begin();
+  Expr bottom;
+  std::set<Expr, CompExpr> visited;
+  bool first_traversal = true;
+  for (Expr loop_i : loop_set) {
+    if (visited.count(loop_i)) {
+      continue;
+    }
+    Expr v_for = loop_i;
+    CHECK(v_for.As<ir::For>());
+    while (v_for.defined()) {
+      // If loop_i's sub loop was already visited, that sub loop must be the previously found top.
+ // Then loop_i should be the new top + if (visited.count(v_for)) { + if (v_for != top) { + LOG(FATAL) << "Loops in GetBoundaryOfReorderRange is not a chain! Please check."; + } + top = loop_i; + break; + } + + // This while loop always GetNextForLoop(sub loop), so the last + // visited v_for in the first traversal will be the bottom. + if (first_traversal && loop_set.count(v_for)) { + bottom = v_for; + } + visited.insert(v_for); + v_for = GetNextForLoop(v_for); + } + first_traversal = false; + } + CHECK(top.As()); + CHECK(bottom.defined()); + CHECK(bottom.As()); + return std::make_pair(top, bottom); +} + +std::vector GetLoopsInRange(const Expr& top, const Expr& bottom) { + std::vector chain; + CHECK(top.As()); + CHECK(bottom.As()); + for (auto loop_iter = top; loop_iter != bottom;) { + Expr tmp = GetNextForLoop(loop_iter); + if (!tmp.defined()) LOG(FATAL) << "Loops in GetLoopsInReorderRange is not a chain! Please check."; + chain.push_back(loop_iter); + loop_iter = tmp; + } + chain.push_back(bottom); + return chain; +} + +// Construct a loop chain such that: +// +// loops[i_1] { +// loops[i_2] { +// ... +// loops[i_n] { +// stmts; +// } +// } +// } +// +// where reordered_indices = {i_1, i_2, ... i_n } +// +// This is a helper function which constructs non-main chain for other body +// statements in Reorder. See comment and call place in ConstructNewLoopChain +Expr ConstructOtherStmtChain(const std::vector& stmts, + const std::vector& loops, + const std::vector reordered_indices) { + Expr new_loop; + for (int i = reordered_indices.size() - 1; i >= 0; --i) { + Expr temp = optim::IRCopy(loops[reordered_indices[i]]); + CHECK(temp.defined()); + CHECK(temp.As()); + if (new_loop.defined()) { + temp.As()->body = Block::Make({new_loop}); + } else { + temp.As()->body = Block::Make({stmts}); + } + new_loop = temp; + } + return new_loop; +} + +Expr ConstructNewLoopChain(const std::vector& chain, + const std::vector& ordered_loops, + const std::set& loop_set, + std::vector& if_nodes) { + std::vector> condition_vars; + // In each IfThenElse node, find the vars its condition depends on. + for (auto& if_expr : if_nodes) { + CHECK(if_expr.As()); + auto var_set = ir::CollectIRNodes(if_expr.As()->condition, [&](const Expr* x) { return x->as_var(); }); + std::set var_name_set; + for (auto& i : var_set) var_name_set.insert(i.as_var()->name); + condition_vars.push_back(var_name_set); + } + Expr new_loop; + int index = static_cast(ordered_loops.size()) - 1; + + std::vector reordered_loop_chain; + // Construct the main loop chain from bottom to top. + for (int i = static_cast(chain.size()) - 1; i >= 0; i--) { + auto& loop_in_chain = chain[i]; + CHECK(loop_in_chain.As()); + Expr temp; + if (loop_set.count(loop_in_chain)) { + CHECK_GE(index, 0); + temp = optim::IRCopy(ordered_loops[index]); + --index; + } else { + temp = optim::IRCopy(loop_in_chain); + } + CHECK(temp.defined()); + CHECK(temp.As()); + // Main chain, each loop's body only contains sub_loop or bottom loop's body + if (new_loop.defined()) { + temp.As()->body = Block::Make({new_loop}); + } else { + temp.As()->body = loop_in_chain.As()->body; + } + Expr original_temp = temp; + // Here we handle the IfThenElse nodes. 
+    for (int i = 0; i < static_cast<int>(if_nodes.size()); ++i) {
+      if (condition_vars[i].count(original_temp.As<ir::For>()->loop_var->name)) {
+        Expr temp_body = temp.As<ir::For>()->body;
+        if (temp_body.As<ir::Block>() && temp_body.As<ir::Block>()->stmts.size() == 1U)
+          temp_body = temp_body.As<ir::Block>()->stmts[0];
+        temp.As<ir::For>()->body = IfThenElse::Make(
+            if_nodes[i].As<ir::IfThenElse>()->condition, temp_body, if_nodes[i].As<ir::IfThenElse>()->false_case);
+        temp.As<ir::For>()->body = Block::Make({temp.As<ir::For>()->body});
+        if_nodes.erase(if_nodes.begin() + i);
+        condition_vars.erase(condition_vars.begin() + i);
+        i--;
+      }
+    }
+    new_loop = temp;
+    reordered_loop_chain.push_back(new_loop);
+  }
+  CHECK(new_loop.defined());
+
+  // reordered_loop_chain, which represents the main loop chain, is now ordered from top to bottom.
+  std::reverse(reordered_loop_chain.begin(), reordered_loop_chain.end());
+
+  // In the main loop chain, each loop's body only contains the sub loop or the
+  // bottom loop's body, but the origin loop chain may contain some other body
+  // stmts, which the main loop chain has lost.
+  // For example:
+  //
+  //   for (i, 0, 32) {                               for (j, 0, 64) {
+  //     other_body_stmts          Reorder j, i         for (i, 0, 32) {
+  //     for (j, 0, 64) {        ---------------->        bottom_loop_body
+  //       bottom_loop_body                             }
+  //     }                                            }
+  //   }
+  //
+  // We go through the origin loops, check the other body stmts, and add them
+  // as another chain above the main chain, such as:
+  //
+  //   for (i, 0, 32) {
+  //     other_body_stmts
+  //   }
+  //   for (j, 0, 64) {
+  //     for (i, 0, 32) {
+  //       bottom_loop_body
+  //     }
+  //   }
+
+  // Construct the complete loop chain from the origin loop top to bottom.
+  CHECK_EQ(chain.size(), reordered_loop_chain.size())
+      << "In ConstructNewLoopChain of Reorder, the origin loop chain size does not equal the reordered one";
+  std::unordered_set<std::string> origin_loop_var_names;
+  Expr ret = new_loop;
+
+  // Maintain an index at which to add a stmt (other body stmt chain)
+  //
+  //   stmt stmt MainChainLoop stmt stmt
+  //        index              index+1
+  //
+  // The index of this MainChainLoop points at the place before the next MainChainLoop.
+  // We can insert statements before MainChainLoop at the index, and insert
+  // statements after MainChainLoop at the index + 1
+  int add_other_chain_index = 0;
+
+  for (int i = 0; i < chain.size() - 1; ++i) {
+    // We only check i < chain.size() - 1
+    // because the bottom loop's body stmts have all been added already
+
+    const ir::For* loop_in_chain = chain[i].As<ir::For>();
+    ir::For* reordered_in_chain = reordered_loop_chain[i].As<ir::For>();
+
+    origin_loop_var_names.insert(loop_in_chain->loop_var->name);
+    CHECK_EQ(origin_loop_var_names.size(), i + 1) << "Duplicate loop var name in the origin chain during Reorder";
+
+    const ir::Block* body_block = loop_in_chain->body.As<ir::Block>();
+
+    if (body_block != nullptr && body_block->stmts.size() > 1) {
+      // contains other body stmts
+
+      // Get the other body statements before and after the loop
+      bool other_stmt_body_before_loop = true;
+      std::vector<Expr> stmts_before_loop;
+      std::vector<Expr> stmts_after_loop;
+      for (int j = 0; j < body_block->stmts.size(); ++j) {
+        if (body_block->stmts[j].As<ir::For>() &&
+            body_block->stmts[j].As<ir::For>()->loop_var->name == chain[i + 1].As<ir::For>()->loop_var->name) {
+          other_stmt_body_before_loop = false;
+          continue;
+        }
+        if (other_stmt_body_before_loop) {
+          stmts_before_loop.push_back(body_block->stmts[j]);
+        } else {
+          stmts_after_loop.push_back(body_block->stmts[j]);
+        }
+      }
+
+      // Find the chain that the other body stmts share with the main loop chain
+      std::vector<int> reordered_indices;
+      for (int j = 0; j < reordered_loop_chain.size(); ++j) {
+        if (origin_loop_var_names.count(reordered_loop_chain[j].As<ir::For>()->loop_var->name)) {
reordered_indices.push_back(j); + } + } + CHECK_EQ(reordered_indices.size(), origin_loop_var_names.size()) + << "Reordered chain loop var names doesn't match other stmt chain loop var names"; + + // Add other stmts chain to root Block if other stmts exist + if (!stmts_before_loop.empty()) { + Expr before_chain = ConstructOtherStmtChain(stmts_before_loop, reordered_loop_chain, reordered_indices); + if (ret.As() == nullptr) { + ret = ir::Block::Make({ret}); + } + std::vector& inplace_stmts = ret.As()->stmts; + auto pos = inplace_stmts.begin() + add_other_chain_index; + inplace_stmts.insert(pos, before_chain); + ++add_other_chain_index; + } + + if (!stmts_after_loop.empty()) { + Expr after_chain = ConstructOtherStmtChain(stmts_after_loop, reordered_loop_chain, reordered_indices); + if (ret.As() == nullptr) { + ret = ir::Block::Make({ret}); + } + std::vector& inplace_stmts = ret.As()->stmts; + auto pos = inplace_stmts.begin() + add_other_chain_index + 1; + inplace_stmts.insert(pos, after_chain); + } + } + } + + return ret; +} + +std::vector GetProducers(const Expr& block, const Expr& root) { + CHECK(block.As()); + CHECK(root.As()); + std::vector producers; + + // collect all producers' tensor names + std::set producer_tensor_names; + auto compute_body = block.As()->schedule_block.As()->body; + ir::CollectIRNodesWithoutTensor(compute_body, [&producer_tensor_names](const Expr* x) { + auto* load = x->As(); + if (load) { + producer_tensor_names.insert(load->tensor.as_tensor()->name); + return true; + } + return false; + }); + + // traverse each of other blocks and filter those ones which contain at least one producer tensor; + auto find_blocks = ir::CollectIRNodesWithoutTensor( + root, [&block, &root](const Expr* x) { return x->As() && *x != block && *x != root; }); + for (auto&& cur : find_blocks) { + auto* cur_block = cur.As()->schedule_block.As(); + CHECK(cur_block) << "block result should be a ScheduleBlockRealize"; + auto find_stores = ir::CollectIRNodesWithoutTensor(cur_block->body, [&producer_tensor_names](const Expr* x) { + return x->As() && producer_tensor_names.count(x->As()->tensor.as_tensor()->name) > 0; + }); + if (!find_stores.empty()) producers.emplace_back(cur); + } + return producers; +} + +std::vector GetConsumers(const Expr& block, const Expr& root) { + CHECK(block.As()); + CHECK(root.As()); + std::vector consumers; + std::string block_tensor = GetTensor(block)->name; + auto find_block = ir::CollectIRNodesWithoutTensor( + root, [&](const Expr* x) { return x->As() && *x != block && *x != root; }); + for (auto& i : find_block) { + CHECK(i.As()->schedule_block.As()); + auto block_body = i.As()->schedule_block.As()->body; + auto find_load = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && x->As()->tensor.as_tensor_ref()->name == block_tensor; + }); + if (!find_load.empty()) consumers.emplace_back(i); + } + return consumers; +} + +void CheckComputeAtValidation(const Expr& block, const Expr& loop, const Expr& root) { + auto find_block = ir::CollectIRNodesWithoutTensor( + root, [&](const Expr* x) { return x->As() && *x == block; }, true); + CHECK(!find_block.empty()) << "Didn't find block in root!"; + + auto find_loop = ir::CollectIRNodesWithoutTensor( + root, [&](const Expr* x) { return x->As() && *x == loop; }, true); + CHECK(!find_loop.empty()) << "Didn't find loop in root!"; + + auto find_block_in_loop = ir::CollectIRNodesWithoutTensor( + loop, [&](const Expr* x) { return x->As() && *x == block; }, true); + CHECK(find_block_in_loop.empty()) << 
"loop should not be block's ancestor!"; +} + +void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { + CHECK(for_loop.As()); + CHECK(for_loop.As()->body.As()); + ir::Block* dst_block = for_loop.As()->body.As(); + CHECK(index == -1 || index >= 0 && index < dst_block->stmts.size()) + << "index = " << index << ", it should be -1 or between [0, block stmts size)"; + + if (index == -1) { + dst_block->stmts.emplace_back(insertion); + } else { + auto dst_it = dst_block->stmts.begin() + index; + if (dst_it->As()) { + auto* inserted_block = dst_it->As()->true_case.As(); + CHECK(inserted_block) << "the IfThenElse node to be inserted shuold contain a true_case block"; + inserted_block->stmts.insert(inserted_block->stmts.begin(), insertion); + } else { + dst_block->stmts.insert(dst_it, insertion); + } + } +} + +IterRange RangeUnion(const IterRange& range1, const IterRange& range2) { + Expr new_min = common::AutoSimplify(Min::Make(range1.min, range2.min)); + Expr new_extent = common::AutoSimplify( + common::AutoSimplify(Max::Make(range1.min + range1.extent, range2.min + range2.extent)) - new_min); + return IterRange(new_min, new_extent); +} + +std::vector CalculateRequiredRegions(const Expr& block, + const Expr& loop, + const Expr& root, + const std::vector& required_blocks, + bool is_store_provided) { + CHECK(block.As()) << "Param block should be a ir::ScheduleBlockRealize node"; + CHECK(loop.As()) << "Param loop should be a ir::For node"; + + std::set provided_nodes; + if (is_store_provided) { + provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); + } else { + provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); + } + + std::vector required_buffer_range; + // deduce accessed regions of the provided tensor in block by itering each required block + for (const Expr& pro_node : provided_nodes) { + const std::string& provided_tensor_name = is_store_provided ? pro_node.As()->tensor.as_tensor()->name + : pro_node.As()->tensor.as_tensor()->name; + + for (const Expr& req_block : required_blocks) { + CHECK(req_block.As()); + Expr block_body = + optim::IRCopy(req_block.As()->schedule_block.As()->body); + auto iter_vars = req_block.As()->schedule_block.As()->iter_vars; + auto iter_values = req_block.As()->iter_values; + ReplaceExpr(&block_body, iter_vars, iter_values); + + // Notice that we look for For nodes in loop's body instead of loop itself. + auto find_loops = ir::CollectIRNodesWithoutTensor( + loop.As()->body, [&](const Expr* x) { return x->As() && Contains(*x, req_block); }); + + // collect vars and their ranges of each loop under the input loop + std::vector loop_vars; + std::vector loop_ranges; + for (const auto& for_loop : find_loops) { + loop_vars.emplace_back(for_loop.As()->loop_var); + loop_ranges.emplace_back(for_loop.As()->min, for_loop.As()->extent); + } + + std::set required_nodes; + if (is_store_provided) { + required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; + }); + } else { + required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { + return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; + }); + } + + // deducing range by indices of each required node + for (const Expr& req_node : required_nodes) { + const auto& indices = is_store_provided ? 
req_node.As()->indices : req_node.As()->indices; + + if (find_loops.empty()) { + for (int i = 0; i < indices.size(); ++i) { + if (i >= required_buffer_range.size()) + required_buffer_range.emplace_back(indices[i], Expr(1)); + else + required_buffer_range[i] = RangeUnion(required_buffer_range[i], IterRange(indices[i], Expr(1))); + } + } else { + for (int i = 0; i < indices.size(); ++i) { + auto range = GetAccessedRange(indices[i], loop_vars, loop_ranges); + if (i >= required_buffer_range.size()) { + required_buffer_range.emplace_back(std::move(range)); + } else { + required_buffer_range[i] = RangeUnion(required_buffer_range[i], range); + } + } + } + } // end for load_nodes + } + } + + int iter_size = block.As()->iter_values.size(); + // maybe some dimensions are not accessed by consumers so we should append them + if (iter_size > required_buffer_range.size()) { + for (int i = required_buffer_range.size(); i < iter_size; ++i) { + CHECK(block.As()->iter_values[i].as_var() || + block.As()->iter_values[i].is_constant()); + if (block.As()->iter_values[i].as_var()) { + auto find_for_loops = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { + return x->As() && x->As()->loop_var->name == + block.As()->iter_values[i].as_var_ref()->name; + }); + CHECK_EQ(find_for_loops.size(), 1U); + required_buffer_range.emplace_back((*find_for_loops.begin()).As()->min, + (*find_for_loops.begin()).As()->extent); + } else { + int cons = (int)block.As()->iter_values[i].is_constant(); + required_buffer_range.emplace_back(Expr(cons), Expr(1)); + } + } + } + return required_buffer_range; +} + +Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root) { + CHECK(schedule_block.As()); + auto compute_body = schedule_block.As()->schedule_block.As()->body; + // 1. Check the schedule block to be inlined is not a reduce tensor. + auto find_store = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_store.size(), 1U); + Expr tensor = (*find_store.begin()).As()->tensor; + CHECK(!tensor.as_tensor_ref()->is_reduce_tensor()); + // 2. Check this schedule block is the only writer of the tensor. + find_store = ir::CollectIRNodesWithoutTensor( + root, + [&](const Expr* x) { + return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; + }, + true); + CHECK_EQ(find_store.size(), 1U); + // 3. Check there is no overlap between the buffers the schedule block reads and writes. + auto find_load = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + CHECK(find_load.empty()); + return (*find_store.begin()); +} + +std::tuple CheckReverseComputeInlineValidationAndGetExprs(const Expr& schedule_block, + const Expr& root) { + CHECK(schedule_block.As()); + auto compute_body = schedule_block.As()->schedule_block.As()->body; + // 1. Check the schedule block to be reverse inlined is not a reduce tensor. + auto find_inlined_load = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_inlined_load.size(), 1U); + Expr tensor = (*find_inlined_load.begin()).As()->tensor; + CHECK(!tensor.as_tensor_ref()->is_reduce_tensor()); + auto inlined_load = *find_inlined_load.begin(); + // 2. Check this schedule block is the only reader of the tensor. 
+ auto find_load = ir::CollectIRNodesWithoutTensor( + root, + [&](const Expr* x) { + return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; + }, + true); + CHECK_EQ(find_load.size(), 1U); + // 3. Check there is no overlap between the buffers the schedule block reads and writes. + auto find_store = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + CHECK(find_store.empty()); + // 4. Get store that will be inlined. + auto find_inlined_store = ir::CollectIRNodesWithoutTensor( + root, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); + CHECK_EQ(find_inlined_store.size(), 1U); + auto inlined_store = *find_inlined_store.begin(); + // 5. Get target store. + auto find_target_store = ir::CollectIRNodesWithoutTensor( + compute_body, [&](const Expr* x) { return x->As(); }, true); + CHECK_EQ(find_target_store.size(), 1U); + auto target_store = *find_target_store.begin(); + return {inlined_load, inlined_store, target_store}; +} + +bool ContainVar(const std::vector& exprs, const std::string& var_name) { + for (auto& expr : exprs) { + auto find_expr = ir::CollectIRNodesWithoutTensor( + expr, [&](const Expr* x) { return x->As<_Var_>() && x->As<_Var_>()->name == var_name; }, true); + if (!find_expr.empty()) return true; + } + return false; +} + +std::unordered_map PrimeFactorize(int n) { + std::unordered_map factors; + while (n % 2 == 0) { + ++factors[2]; + n /= 2; + } + for (int i = 3; i <= sqrt(n); i += 2) { + while (n % i == 0) { + ++factors[i]; + n /= i; + } + } + if (n > 2) { + factors[n] = 1; + } + return factors; +} + +std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, int n, int extent) { + std::vector tile; + while (n > 1) { + std::unordered_map factors = PrimeFactorize(extent); + int product = 1; + for (auto& factor : factors) { + if (factor.second >= 1) { + int num = utils::SampleUniformInt(1, factor.second + 1, rand_seed); + product *= std::pow(factor.first, num); + } + } + tile.push_back(product); + extent /= product; + --n; + } + tile.push_back(extent); + return tile; +} +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_schedule_util.h b/paddle/cinn/ir/ir_schedule_util.h new file mode 100644 index 0000000000000..12a80f637969c --- /dev/null +++ b/paddle/cinn/ir/ir_schedule_util.h @@ -0,0 +1,448 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/tensor.h" +#include "cinn/utils/random_engine.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +// Self-defined operator to support std::set +struct CompExpr { + bool operator()(const Expr& left, const Expr& right) const { + return utils::GetStreamCnt(left) < utils::GetStreamCnt(right); + } +}; + +// Self-defined operator to support std::set +struct CompVar { + bool operator()(const Var& left, const Var& right) const { return left->name < right->name; } +}; + +struct MappingVarToExprMutator : public ir::IRMutator<> { + MappingVarToExprMutator(const std::map& replacing_map) : replacing_map_(replacing_map) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::_Var_* expr, Expr* op) override { + if (replacing_map_.count(op->as_var_ref())) { + *op = replacing_map_.at(op->as_var_ref()); + } + } + + private: + const std::map& replacing_map_; +}; + +struct FindLoopsVisitor { + FindLoopsVisitor(const Expr& block) : block_(block) {} + + std::vector operator()(const Expr* expr) { + CHECK(block_.As()); + visit_end = false; + Visit(expr); + return result; + } + + private: + void Visit(const Expr* expr) { + if (visit_end || !expr->defined()) return; + if (expr->As()) { + father_loops.emplace_back(*expr); + Visit(&(expr->As()->body)); + father_loops.pop_back(); + } else if (expr->As()) { + if (!expr->As()->iter_values.empty() && (*expr == block_)) { + result = father_loops; + visit_end = true; + return; + } else { + Visit(&(expr->As()->schedule_block)); + } + } else if (expr->As()) { + Visit(&(expr->As()->body)); + } else if (expr->As()) { + for (auto& n : expr->As()->stmts) Visit(&n); + } else if (expr->As()) { + Visit(&(expr->As()->true_case)); + Visit(&(expr->As()->false_case)); + } + } + + std::vector father_loops{}; + std::vector result{}; + bool visit_end{false}; + const Expr& block_; +}; + +/** + * \brief Given a ScheduleBlockRealize node, return the Store tensor in its body. + * @param block The given ScheduleBlockRealize node + * @return The Store tensor in block + */ +Tensor GetTensor(const Expr& block); + +struct FindBlocksVisitor { + FindBlocksVisitor(const std::string& block_name = "") : block_name_(block_name) {} + + std::vector operator()(const Expr* expr) { + Visit(expr); + return result; + } + + private: + void Visit(const Expr* expr) { + if (!expr->defined()) return; + if (!block_name_.empty() && !result.empty()) return; + if (expr->As()) { + Visit(&(expr->As()->body)); + } else if (expr->As()) { + if (!expr->As()->iter_values.empty()) { + auto* schedule_block = expr->As()->schedule_block.As(); + if (block_name_.empty() || schedule_block->name == block_name_) { + result.emplace_back(*expr); + } + } else { + Visit(&(expr->As()->schedule_block)); + } + } else if (expr->As()) { + Visit(&(expr->As()->body)); + } else if (expr->As()) { + for (auto& n : expr->As()->stmts) Visit(&n); + } else if (expr->As()) { + Visit(&(expr->As()->true_case)); + Visit(&(expr->As()->false_case)); + } + } + std::string block_name_; + std::vector result{}; +}; + +struct CacheBlockInfo { + /*! \brief The tensor to be read. */ + Tensor read_tensor; + /*! \brief The tensor to be written. */ + Tensor write_tensor; + /*! \brief The tensor allocation to be inserted into the block signature. */ + Tensor alloc; + /*! 
\brief The AST node whose body is where the cache stage should be inserted. */
+  Expr loc_block;
+  /*! \brief The index at which to insert the cache_read/cache_write stage. */
+  int loc_pos;
+  /*! \brief The cache_read/cache_write stage to be inserted. */
+  Expr cache_block;
+};
+
+// A struct representing the min value and the extent of an iterable range,
+// which is a semi-closed interval, i.e. [min, min + extent)
+struct IterRange {
+  IterRange(Expr begin, Expr length) : min(begin), extent(length) {}
+
+  Expr min;
+  Expr extent;
+};
+
+/**
+ * \brief Given a ScheduleBlockRealize node, return the index-th Load tensor in its body.
+ * @param block The given ScheduleBlockRealize node
+ * @param index The index of the Load tensor
+ * @return The index-th Load tensor in block
+ */
+Tensor GetReadTensor(const Expr& block, int index);
+
+/**
+ * \brief Given a For node, return its extent as int.
+ * @param loop The given For node
+ * @return The extent of the For node
+ */
+int GetLoopExtent(const Expr& loop);
+
+/**
+ * \brief Given a vector of Exprs, return whether they contain a Var with a specific name.
+ * @param exprs The given vector of Exprs
+ * @param var_name The name of the specific Var
+ * @return Whether there is a Var with the same name as var_name
+ */
+bool ContainVar(const std::vector<Expr>& exprs, const std::string& var_name);
+
+/**
+ * \brief Given a _LoweredFunc_, set its cuda_axis_info based on its func_body.
+ * @param lowered_func A pointer to the given _LoweredFunc_
+ */
+void SetCudaAxisInfo(Expr* lowered_func);
+
+/*!
+ * \brief Check if an Expr node contains a ScheduleBlockRealize node.
+ * \param container The container Expr node.
+ * \param expr The node we want to find.
+ * \return Whether the container contains the expr.
+ */
+bool Contains(const Expr& container, const Expr& expr);
+
+/**
+ * \brief Given a For loop, return the next For loop in its body.
+ * @param for_loop The given For loop.
+ * @return The next For loop.
+ */
+Expr GetNextForLoop(const Expr& for_loop);
+
+/**
+ * \brief Given two For loops, return all ir::IfThenElse nodes between them.
+ * @param top The given top For loop.
+ * @param bottom The given bottom For loop.
+ * @return All ir::IfThenElse nodes between them.
+ */
+std::vector<Expr> GetIfThenElseInRange(const Expr& top, const Expr& bottom);
+
+/**
+ * Replace the Vars in replaced with the Exprs in candidates in source. Vars -> Exprs is a one-to-one correspondence.
+ * @param source The Expr we will implement the change in.
+ * @param replaced The Vars to be replaced.
+ * @param candidates The Exprs that replace the Vars in replaced.
+ */
+void ReplaceExpr(Expr* source, const std::vector<Var>& replaced, const std::vector<Expr>& candidates);
+
+/**
+ * Validate the factors param of Split. We check whether the factors are valid and change -1 to a positive integer.
+ * @param factors The original factors.
+ * @param total_extent The extent of the loop to be split.
+ * @return The validated factors.
+ */
+std::vector<int> ValidateFactors(const std::vector<int>& factors, int total_extent);
+
+void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis);
+
+/**
+ * Return the loops that contain the expr.
+ * @param expr The expr.
+ * @param root The root of the whole AST.
+ * @return The loops in the AST that contain the expr.
+ */
+std::vector<Expr> GetLoopsOfExpr(const Expr& expr, const Expr& root);
+
+/**
+ * Given an index Expr and all vars' ranges, return the accessed range of this index.
+ * @param index The Expr of a specified index.
+ * @param iter_vars The vars in the index expr.
+ * @param iter_range Each var's range. + * @return return an IterRange represents the accessed range of this indice, If it is not constant, return corresponding + * tensor's shape. + */ +IterRange GetAccessedRange(const Expr& index, + const std::vector& iter_vars, + const std::vector& iter_ranges); + +/** + * Given a ScheduleBlockRealize, an AST root, a tensor and its tensor_indices, return the accessed buffer region of the + * tensor in block. + * @param block The ScheduleBlockRealize. + * @param tensor_indices The tensor's indices. + * @param tensor The tensor. + * @param root The root of whole AST. + * @return return The accessed buffer region of the tensor in block. + */ + +std::vector CalculateTensorRegions(const Expr& block, + const std::vector& tensor_indices, + const Tensor& tensor, + const Expr& root); + +/** + * Return n-th access tensor in block + * @param block The ScheduleBlockRealize. + * @param index The index indicating which tensor we want to get. + * @param is_write We want to get write tensor or read tensor. + * @return return The n-th access tensor in block. Should be ir::Store(is_write) or ir::Load(!is_write). + */ +Expr GetNthAccessExpr(const Expr& block, int index, bool is_write); + +/** + * Make a tensor's cache tensor. + * @param tensor The original tensor. + * @param memory_type The memory type of the cache tensor. + * @return return The tensor's cache tensor. + */ +Tensor MakeCacheTensor(const Tensor& tensor, const std::string& memory_type); + +/** + * Make a the cache tensor's block. + * @param buffer_region The accessed region of cache tensor. + * @param info The information of cache block. + * @param memory_type The memory type of cache tensor. + * @param device_api The device api of this Expr. + * @return return ScheduleBlockRealize of the cache tensor. + */ +Expr MakeCacheBlock(const std::vector& buffer_ranges, + CacheBlockInfo* info, + const std::string& memory_type, + DeviceAPI device_api); + +/** + * Fidn cache tensor block's insertion point in the whole AST(root). + * @param root The whole AST. + * @param info The information of cache block. + * @param is_write Are we inserting a write cache tensor or a read cache tensor. + */ +void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write); + +/** + * \brief Given a vector of For loops, return a set of them. + * @param loops The given vector of For loops. + * @return A set containing all the For loops in loops. + */ +const std::set CollectLoopsToSet(const std::vector& loops); + +/** + * \brief Given a set of For loops, return the boundary among them. + * @param loop_set The given set of For loops. + * @return A pair of the boundary among For loops.(The top For and bottom For) + */ +std::pair GetBoundaryOfReorderRange(const std::set& loop_set); + +/** + * \brief Given two For loops, return all loops between them. + * @param top The top For loop. + * @param bottom The bottom For loop. + * @return A vector containing all For loops between the boundary, stored in ascending order. + */ +std::vector GetLoopsInRange(const Expr& top, const Expr& bottom); + +/** + * \brief Given params, construct a new loop. + */ +Expr ConstructNewLoopChain(const std::vector& chain, + const std::vector& ordered_loops, + const std::set& loop_set, + std::vector& if_nodes); + +/*! + * \brief Find producers of block in root. + * \param block The ScheduleBlockRealize node we want to find its producers. + * \param root The root ScheduleBlockRealize node. + * \return block's producers(ScheduleBlockRealize nodes) in root. 
+ */ +std::vector GetProducers(const Expr& block, const Expr& root); + +/*! + * \brief Find consumers of block in root. + * \param block The ScheduleBlockRealize node we want to find its consumers. + * \param root The root ScheduleBlockRealize node. + * \return block's consumers(ScheduleBlockRealize nodes) in root. + */ +std::vector GetConsumers(const Expr& block, const Expr& root); + +/*! + * \brief Check if the params of ComputeAt is validate. + * \param block The block node we want to move in ComputeAt. + * \param loop The for node we want to put the block under in ComputeAt. + * \param root The root ScheduleBlockRealize node of block and loop. + */ +void CheckComputeAtValidation(const Expr& block, const Expr& loop, const Expr& root); + +/*! + * \brief Insert a new ScheduleBlockRealize in a loop's body(under its IfThenElse Node, if any) + * \param for_loop The for loop whose body we want to modify + * \param insertion The ScheduleBlockRealize we want to insert + * \param index The position index of the for_loop body `stmts` to be inserted: + * - `index = -1` means inserted into the tail + * - otherwise, it should be a index between [0, stmts size) + */ +void InsertBlock(Expr& for_loop, const Expr& insertion, int index = 0); + +/*! + * \brief Make a union of two range. The detailed function is : + * new_range.min = min(range1.min, range2.min) + * new_range.extent = max(range1.min + range1.extent, range2.min + range2.extent) - new_range.min + * Notice that the pair indicates a range's min and extent. + * \param range1 The first range + * \param range2 The second range + * \return The union of these two ranges + */ +IterRange RangeUnion(const IterRange& range1, const IterRange& range2); + +/*! + * \brief Calculate the required buffer region given a block and its required blocks. + * For example, if block is : + * B[i0, j0] = A[i0, j0] + * loop is : + * for (i, 0, 64) { + * for (j, 0, 64) { + * C[i, j] = B[i, j] + * } + * } + * And required_blocks is : + * C[i, j] = B[i, j] + * Then we get the required B's region: + * B[i, j], where: + * i : [i, i] + * j : [0, 64] + * \param block The ScheduleBlockRealize node begin required + * \param loop The loop where we will insert the block under it + * @param root The root of the whole AST. + * \param required_blocks vector of ScheduleBlockRealize nodes that require the block + * \param is_store_provided Whether Store nodes of the block provide the tensor, + * true means it is in compute_at case, otherwise false means in reverse_compuate_at case + * \return Each index's range of block's tensor. Indicating the buffer region being required. + */ +std::vector CalculateRequiredRegions(const Expr& block, + const Expr& loop, + const Expr& root, + const std::vector& required_blocks, + bool is_store_provided = true); + +Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root); + +/*! + * \brief Check if the reverse compute inline validation passes for a given schedule block and root expression, + * and retrieve the store expression if so. + * Reverse compute inline validation ensures that the outputs of a loop nest are properly computed in reverse order. + * \param schedule_block The schedule block to check. + * \param root The root expression of the loop nest. + * \return A tuple containing the load that will be inlined, the store that will be inlined and the target store. + */ +std::tuple CheckReverseComputeInlineValidationAndGetExprs(const Expr& schedule_block, + const Expr& root); + +/*! 
+ * \brief Get the prime factors of a number. + * For example, 12 = 2^2 * 3^1, then the return value is {2: 2, 3: 1}. + * \param n The number to be factorized. + * \return A map of prime factors and their corresponding exponents. + */ +std::unordered_map PrimeFactorize(int n); + +/*! + * \brief Given a number returns the form of the product of its n factors + * For example: + * n = 2, dividend = 12, return one of {2, 6}, {6, 2}, {3, 4}, {4, 3} + * \param seed The random number generator to use. + * \param n The number to be factorized. + * \param dividend The dividend of the number. + */ +std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, int n, int dividend); +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_test.cc b/paddle/cinn/ir/ir_test.cc new file mode 100644 index 0000000000000..39ec6b0073f58 --- /dev/null +++ b/paddle/cinn/ir/ir_test.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir.h" + +#include + +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +TEST(Expr, basic) { + Expr a(1); + auto b = Expr(a); + LOG(INFO) << b.as_int32(); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_verify.cc b/paddle/cinn/ir/ir_verify.cc new file mode 100644 index 0000000000000..b9f3fc7226e14 --- /dev/null +++ b/paddle/cinn/ir/ir_verify.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_verify.h" + +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_printer.h" + +namespace cinn::ir { + +struct IrVerifyVisitor : public ir::IRMutator<> { + using ir::IRMutator<>::Visit; + +#define __(op__) \ + void Visit(const op__ *op, Expr *expr) override { \ + op->Verify(); \ + IRMutator::Visit(op, expr); \ + } + NODETY_FORALL(__) +#undef __ +}; + +void IrVerify(Expr e) { + IrVerifyVisitor visitor; + visitor.Visit(&e, &e); +} + +} // namespace cinn::ir diff --git a/paddle/cinn/ir/ir_verify.h b/paddle/cinn/ir/ir_verify.h new file mode 100644 index 0000000000000..fa2fe259ef127 --- /dev/null +++ b/paddle/cinn/ir/ir_verify.h @@ -0,0 +1,22 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir.h" + +namespace cinn::ir { + +void IrVerify(Expr e); + +} // namespace cinn::ir diff --git a/paddle/cinn/ir/ir_verify_test.cc b/paddle/cinn/ir/ir_verify_test.cc new file mode 100644 index 0000000000000..5fcfe4cc8dcef --- /dev/null +++ b/paddle/cinn/ir/ir_verify_test.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_verify.h" + +#include + +#include "cinn/ir/ir_operators.h" + +namespace cinn::ir { + +TEST(IrVerify, basic) { + Expr a(1); + Expr b(1); + IrVerify(a + b); +} + +} // namespace cinn::ir diff --git a/paddle/cinn/ir/ir_visitor.cc b/paddle/cinn/ir/ir_visitor.cc new file mode 100644 index 0000000000000..0cdbc828a91a2 --- /dev/null +++ b/paddle/cinn/ir/ir_visitor.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/ir_visitor.h" + +#include + +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/tensor.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace ir { + +bool operator==(Expr a, Expr b) { + if (a.get() == b.get()) return true; + // TODO(Superjomn) implement with a more accurate one + return utils::GetStreamCnt(a) == utils::GetStreamCnt(b); +} + +bool operator!=(Expr a, Expr b) { return !(a == b); } + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/ir_visitor.h b/paddle/cinn/ir/ir_visitor.h new file mode 100644 index 0000000000000..21d7bab369ae8 --- /dev/null +++ b/paddle/cinn/ir/ir_visitor.h @@ -0,0 +1,82 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/collect_ir_nodes.h" +#include "cinn/ir/intrinsic_ops.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace ir { + +struct _Tensor_; + +/** + * Base class of all the methods visit the IR tree. + * @param RetTy return type. + * @param Args type of the extra arguments passed to the all the methods. + */ +template +struct IRVisitorBase { + //! Visit a expression. + // @{ + virtual RetTy Visit(const ir::Expr* expr, Args... args) { + CHECK(expr->defined()); + switch (expr->node_type()) { +#define __(op__) \ + case ir::IrNodeTy::op__: \ + return Visit(expr->As(), args...); + + NODETY_FORALL(__) + + default: + LOG(FATAL) << "not supported NodeTy"; +#undef __ + } + return RetTy(); + } + // @} + + protected: +#define __(op__) virtual RetTy Visit(const ir::op__* op, Args... args) = 0; + NODETY_FORALL(__) +#undef __ +}; + +/** + * Base of all the Ir readonly visitor. + */ +struct IRVisitor : public IRVisitorBase { + IRVisitor() = default; + + void Visit(const Expr* x) { IRVisitorBase::Visit(x); } +#define __m(t__) \ + virtual void Visit(const t__* x) {} + NODETY_FORALL(__m) +#undef __m +}; + +// std::set CollectIRNodes(Expr expr, std::function teller); + +bool operator==(Expr a, Expr b); +bool operator!=(Expr a, Expr b); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/layout.cc b/paddle/cinn/ir/layout.cc new file mode 100644 index 0000000000000..9b97c0e5ecab2 --- /dev/null +++ b/paddle/cinn/ir/layout.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
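+
+// Editor's note: an illustrative usage sketch, not part of the original patch.
+// The string constructor below treats an upper-case letter as a primal axis and
+// a digit-prefixed lower-case letter as a sub-axis split by that factor, e.g.:
+//
+//   ir::Layout layout("NCHW16c");
+//   // layout.axes()       -> {N, C, H, W, c}, with 'c' carrying factor 16
+//   // layout.axis_names() -> "NCHWc" (filled in by Verify())
+//   // Verify() also checks each sub-axis ('c') has a primal axis ('C').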
+ +#include "cinn/ir/layout.h" + +namespace cinn { +namespace ir { + +void Layout::Verify() { + { + CHECK(!name_.empty()); + CHECK(!axes_.empty()); + axis_names_ = ""; + for (auto& axis : axes_) { + CHECK_EQ(axis->name.size(), 1U); + auto axis_name = axis->name[0]; + CHECK((axis_name >= 'A' && axis_name <= 'Z') || (axis_name >= 'a' && axis_name <= 'z')); + CHECK(axis_names_.find(axis_name) == axis_names_.npos) << axis_name << " has already exsit."; + axis_names_ += axis_name; + } + int offset = 'A' - 'a'; + for (auto& axis : axes_) { + CHECK_EQ(axis->name.size(), 1U); + auto axis_name = axis->name[0]; + if (axis_name >= 'a' && axis_name <= 'z') { + CHECK(axis_names_.find(axis_name + offset) != axis_names_.npos) + << "sub-axis " << axis_name << " finds no primal axis"; + } + } + } +} +Layout::Layout(const std::string& name) { + CHECK(!name.empty()); + int factor = 0; + std::vector axes; + for (char c : name) { + if (c >= 'A' && c <= 'Z') { + CHECK_EQ(factor, 0) << "Invalid factor " << factor << " before primal axis " << c; + axes.push_back(ir::Var(std::string(1, c))); + } else if (c >= '0' && c <= '9') { + factor = 10 * factor + c - '0'; + } else if (c >= 'a' && c <= 'z') { + CHECK_GT(factor, 0) << "Invalid factor " << factor << " for sub-axis " << c; + axes.push_back(ir::Var(factor, std::string(1, c))); + factor = 0; + } else { + LOG(FATAL) << "Invalid layout: " << name; + } + } + name_ = name; + axes_ = axes; + Verify(); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/layout.h b/paddle/cinn/ir/layout.h new file mode 100644 index 0000000000000..1af93114c93bd --- /dev/null +++ b/paddle/cinn/ir/layout.h @@ -0,0 +1,48 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_base.h" + +namespace cinn { +namespace ir { +class Layout { + public: + std::string name_; + std::string axis_names_; + std::vector axes_; + + Layout(const std::string& name, const std::vector& axes) : name_(name), axes_(axes) { Verify(); } + + explicit Layout(const std::string& name); + + inline const std::string& name() const { return name_; } + // axis name without factor + inline const std::string& axis_names() const { return axis_names_; } + inline const std::vector& axes() const { return axes_; } + inline int ndims() const { return axes_.size(); } + inline const Var operator[](int i) const { return axes_[i]; } + inline const char axis_names(int i) const { return axis_names_[i]; } + + void Verify(); + Expr Make(const std::string& name, const std::vector& axes); +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc new file mode 100644 index 0000000000000..36b0dcf6014c8 --- /dev/null +++ b/paddle/cinn/ir/lowered_func.cc @@ -0,0 +1,472 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/lowered_func.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/common/common.h" +#include "cinn/common/ir_util.h" +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_visitor.h" +#include "cinn/optim/tensor_write_tell.h" +#include "cinn/runtime/intrinsic.h" +#include "cinn/utils/string.h" +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace ir { + +using common::bfloat16; +using common::float16; + +const _LoweredFunc_* LoweredFunc::operator->() const { return As<_LoweredFunc_>(); } +_LoweredFunc_* LoweredFunc::operator->() { return As<_LoweredFunc_>(); } + +LoweredFunc _LoweredFunc_::Make(const std::string& name, + const std::vector& args, + const Expr& body, + const std::vector& temp_bufs) { + auto* n = make_shared<_LoweredFunc_>(); + n->name = name; + n->args = args; + n->body = body; + n->temp_bufs = temp_bufs; + + n->CheckValid(); + n->PrepareAllocOutputBufferExprs(); + n->PrepareCreateTempBufferExprs(); + n->PrepareAllocTempBufferExprs(); + n->AllocTempBuffer(); + bool with_expr_gen_tensor = true; + if (FLAGS_cinn_ir_schedule) with_expr_gen_tensor = false; + n->PrepareBufferCastExprs(with_expr_gen_tensor); + n->PrepareArgumentExprs(); + n->PrepareDeallocTempBufferExprs(); + n->PrepareDeallocOutputBufferExprs(); + return LoweredFunc(n); +} + +void _LoweredFunc_::CheckValid() const { + // check there is at least one output + int out_count = 0; + int in_count = 0; + for (auto& arg : args) { + in_count += arg.is_input(); + out_count += arg.is_output(); + } + CHECK_GT(out_count, 0) << "At least one output argument is needed for a function\n" << body; +} + +std::vector _LoweredFunc_::expr_fields() { return {&body}; } +std::vector _LoweredFunc_::expr_fields() const { return {&body}; } + +void _LoweredFunc_::PrepareCudaAxisInfoFromBody() { + std::set bound_for_exprs = ir::CollectIRNodes(body, [](const Expr* expr) { + const ir::For* for_expr = expr->As(); + return for_expr != nullptr && for_expr->is_binded(); + }); + + if (bound_for_exprs.empty()) { + device_api = ir::DeviceAPI::GPU; + cuda_axis_info.set_grid_dim(0, 1); + cuda_axis_info.set_block_dim(0, 1); + cuda_axis_info.set_valid(true); + return; + } + + // bound_for_exprs.empty() is false + for (const Expr& expr : bound_for_exprs) { + const ir::For* for_expr = expr.As(); + if (for_expr->for_type() == ir::ForType::GPUBlock) { + cuda_axis_info.set_grid_dim(for_expr->bind_info().offset, for_expr->extent.as_int32()); + } else if (for_expr->for_type() == ir::ForType::GPUThread) { + cuda_axis_info.set_block_dim(for_expr->bind_info().offset, for_expr->extent.as_int32()); + } + } + device_api = ir::DeviceAPI::GPU; + cuda_axis_info.set_valid(true); +} + +void _LoweredFunc_::PrepareAllocOutputBufferExprs() { + CHECK(alloc_output_buffer_exprs.empty()) << "duplicate prepare the allocate buffer for outputs"; + + std::set buffer_names; + for (auto& arg : args) { + if (arg.is_output()) { + 
CHECK(arg.type().valid()) << "argument [" << arg.name() << "]'s type should be set"; + if (arg.is_buffer() && !buffer_names.count(arg.name())) { // only buffer need allocation. + buffer_names.insert(arg.name()); // Avoid duplicate + alloc_output_buffer_exprs.push_back( + Alloc::Make(arg.buffer_arg(), arg.buffer_arg()->type(), arg.buffer_arg()->shape, Expr(), Expr())); + } + } + } +} + +std::vector _LoweredFunc_::PrepareAllocTempBufferExprs() const { + std::vector alloc_temp_buffer_exprs; + for (auto& temp_buf : temp_bufs) { + if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { + alloc_temp_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); + } + } + return alloc_temp_buffer_exprs; +} + +std::vector _LoweredFunc_::PrepareDeallocTempBufferExprs() const { + std::vector dealloc_temp_buffer_exprs; + for (auto& temp_buf : temp_bufs) { + if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { + dealloc_temp_buffer_exprs.push_back(Free::Make(temp_buf)); + } + } + return dealloc_temp_buffer_exprs; +} + +std::vector _LoweredFunc_::PrepareCreateTempBufferExprs() const { + std::vector create_temp_buffer_exprs; + for (auto& temp_buf : temp_bufs) { + if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { + auto expr = ir::intrinsics::BufferCreate::Make(temp_buf); + auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); + Var variable = ir::_Var_::Make(temp_buf->name, buffer_ptr_type); + expr = ir::Let::Make(variable, expr); + create_temp_buffer_exprs.push_back(expr); + } + } + return create_temp_buffer_exprs; +} + +std::vector _LoweredFunc_::CudaPrepareAllocTempBufferExprs() const { + std::vector alloc_output_buffer_exprs; + for (auto temp_buf : temp_bufs) { + if (utils::Startswith(temp_buf->name, "_")) { + temp_buf->name = temp_buf->name.substr(1); + } + if (!temp_buf->shape.empty() && temp_buf->type() != Void()) { + alloc_output_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr())); + } + } + return alloc_output_buffer_exprs; +} + +void _LoweredFunc_::PrepareDeallocOutputBufferExprs() { + CHECK(dealloc_output_buffer_exprs.empty()) << "duplicate prepare the allocate buffer for outputs"; + + std::set buffer_names; + for (auto& arg : args) { + if (arg.is_output()) { + CHECK(arg.type().valid()) << "argument [" << arg.name() << "]'s type should be set"; + if (arg.is_buffer() && !buffer_names.count(arg.name())) { // only buffer need allocation. + buffer_names.insert(arg.name()); // Avoid duplicate + dealloc_output_buffer_exprs.push_back(Free::Make(arg.buffer_arg())); + } + } + } +} + +void _LoweredFunc_::AllocTempBuffer() {} + +void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) { + buffer_data_cast_exprs.clear(); + // collect write. + optim::TensorWriteTeller write_teller; + write_teller.Collect(&body); + + auto tensors = CollectAllTensorReference(with_expr_gen_tensor); + std::sort(tensors.begin(), tensors.end(), [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); + + VLOG(3) << "Function used " << tensors.size() << " buffers"; + for (auto& tensor : tensors) { + auto* node = tensor.As(); + CHECK(node); + if (!tensor->buffer.defined()) continue; + + Type value_type = tensor->type().ElementOf(); + bool is_const = !write_teller.IsWrite(tensor->name); + value_type.set_cpp_handle(); + value_type.set_cpp_const(is_const); + Var variable = _Var_::Make(tensor->name, value_type); + + Expr body = is_const ? 
ir::intrinsics::BufferGetDataConstHandle::Make(tensor->buffer) + : ir::intrinsics::BufferGetDataHandle::Make(tensor->buffer); + + Type target_type = is_const ? tensor->buffer->dtype.PointerOf().ConstOf() : tensor->buffer->dtype.PointerOf(); + body = ir::Cast::Make(target_type, body); + auto let = Let::Make(variable, body); + + buffer_data_cast_exprs.push_back(let); + } +} + +std::vector _LoweredFunc_::CudaAliasVarExprs() const { + std::unordered_set args_buffer; + for (auto arg : args) { + args_buffer.insert(arg.name()); + } + // collect write. + std::vector res; + optim::TensorWriteTeller write_teller; + write_teller.Collect(&body); + + auto tensors = CollectAllTensorReference(); + std::sort(tensors.begin(), tensors.end(), [](const Tensor& a, const Tensor& b) { return a->name < b->name; }); + + for (auto& tensor : tensors) { + auto* node = tensor.As(); + CHECK(node); + if (!tensor->buffer.defined()) { + continue; + } + if (tensor->name == tensor->buffer->name.substr(1) || args_buffer.count(tensor->buffer->name) == 0) { + continue; + } + Type value_type = tensor->type().ElementOf(); + bool is_const = !write_teller.IsWrite(tensor->name); + value_type.set_cpp_handle(); + value_type.set_cpp_const(is_const); + Var variable = _Var_::Make(tensor->name, value_type); + Var body = Var(tensor->buffer->name.substr(1), value_type); + + auto let = Let::Make(variable, body); + + res.push_back(let); + } + return res; +} + +void _LoweredFunc_::PrepareArgumentExprs() { + // Seems a CINN func. + if (args.front().is_var() && args.front().var_arg()->type() == type_of()) return; + + // type of `void*` + auto void_ptr_array_type = Type().with_type(Type::type_t::Void).set_cpp_handle(); + // type of `cinn_buffer_t*` + auto buffer_ptr_type = Type().set_customized_type(common::customized_type::kbuffer_t).set_cpp_handle(); + // type of `const cinn_buffer_t*` + auto const_buffer_ptr_type = buffer_ptr_type.with_cpp_const(); + CHECK(!buffer_ptr_type.is_cpp_const()); + + Var args_passed_in("_args", type_of()); + auto pod_value_ptr = common::CastIfNeeded(args_passed_in, type_of()); + + if (FLAGS_cinn_runtime_display_debug_info) { + argument_prepare_exprs.push_back(runtime::IntrinsicCall( + Void(), runtime::intrinsic::print_debug_args_repr, {pod_value_ptr, common::make_const(Int(32), args.size())})); + } + + /* + * Get something like: + * + * const cinn_buffer_t* _A = args[0]; + * cinn_buffer_t* _B = (cinn_buffer_t*)args[1]; + * int M = (int)arg[2]; + */ + + // We just has two kinds of argument types, first is `cinn_buffer_t*`, second is `const cinn_buffer_t*`, do not need a + // `any` type support currently. + for (int i = 0; i < args.size(); i++) { + auto& arg = args[i]; + // cast arg to cinn_pod_value_t* + + // something like `_args[0]` + Expr load_expr = Load::Make(pod_value_ptr, {common::make_const(i)}); + CHECK_EQ(load_expr.type(), type_of()); + load_expr = ir::intrinsics::GetAddr::Make(load_expr); + + Var _arg; + bool is_const = arg.is_input(); + + if (arg.is_buffer()) { + auto buffer_type = is_const ? 
const_buffer_ptr_type : buffer_ptr_type; + _arg = Var(arg.name(), buffer_type); + } else if (arg.is_var()) { + _arg = Var(arg.name(), arg.var_arg()->type()); + } else { + CINN_NOT_IMPLEMENTED + } + + CHECK(_arg->type().valid()); + + Expr pod_cast_expr; + + if (arg.is_buffer()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else if (arg.type() == type_of()) { + pod_cast_expr = ir::intrinsics::PodValueToX::Make(load_expr, type_of()); + } else { + LOG(ERROR) << "Not supported type [" << arg.type() << "]"; + CINN_NOT_IMPLEMENTED + } + + Expr let_expr = Let::Make(_arg, pod_cast_expr); + CHECK(let_expr.type().valid()); + argument_prepare_exprs.push_back(let_expr); + } +} + +std::vector _LoweredFunc_::CollectAllTensorReference(bool with_expr_gen_tensor) const { + std::set tensor_exprs = + with_expr_gen_tensor + ? ir::CollectIRNodes(body, [](const Expr* expr) { return expr->As(); }) + : ir::CollectIRNodesWithoutTensor(body, [](const Expr* expr) { return expr->As(); }); + + std::vector tensors; + // remove the duplicate tensor by their name. 
+ std::set names; + + for (const Expr& expr : tensor_exprs) { + Expr& _expr = *const_cast(&expr); + Tensor b(_expr.As<_Tensor_>()); + if (names.count(b->name)) continue; + tensors.push_back(b); + names.insert(b->name); + } + + return tensors; +} + +ir::Buffer Argument::buffer_arg() const { + CHECK(is_buffer()); + return buffer_arg_; +} + +ir::Var Argument::var_arg() const { + CHECK(is_var()); + return var_arg_; +} + +void Argument::set_buffer(const ir::Buffer& x) { + CHECK(!is_var()) << "the buffer is already a var"; + buffer_arg_ = x; +} + +void Argument::set_var(const ir::Var& x) { + CHECK(!is_buffer()) << "the buffer is already a buffer"; + var_arg_ = x; +} + +Argument::Argument(const ir::Buffer& buffer, Argument::IO io) { + set_buffer(buffer); + this->io = io; +} + +Type Argument::type() const { + if (is_var()) + return var_arg()->type(); + else if (is_buffer()) + return buffer_arg()->type(); + else + CINN_NOT_IMPLEMENTED +} + +std::string Argument::name() const { + if (is_buffer()) + return buffer_arg()->name; + else if (is_var()) + return var_arg()->name; + else + CINN_NOT_IMPLEMENTED + return ""; +} + +Argument::Argument(const ir::Var& var, Argument::IO io) { + set_var(var); + this->io = io; +} + +std::string Argument::human_readable() const { + std::stringstream os; + os << ""; + return os.str(); +} + +std::ostream& operator<<(std::ostream& os, const CudaAxisInfo& x) { + os << ""; + os << ""; + return os; +} + +void CudaAxisInfo::set_grid_dim(int offset, int x) { + valid_ = true; + CHECK_LT(offset, 3); + grid_dims_[offset] = x; +} +void CudaAxisInfo::set_block_dim(int offset, int x) { + valid_ = true; + CHECK_LT(offset, 3); + block_dims_[offset] = x; +} +int CudaAxisInfo::grid_dim(int offset) const { + CHECK(valid_); + CHECK_LT(offset, 3); + return grid_dims_[offset]; +} +int CudaAxisInfo::block_dim(int offset) const { + CHECK(valid_); + CHECK_LT(offset, 3); + return block_dims_[offset]; +} +void CudaAxisInfo::ExtendWith(const CudaAxisInfo& other) { + set_valid(true); + for (int i = 0; i < 3; i++) { + grid_dims_[i] = std::max(grid_dims_[i], other.grid_dims_[i]); + block_dims_[i] = std::max(block_dims_[i], other.block_dims_[i]); + } +} +void CudaAxisInfo::CopyGridDimsTo(std::vector* dest) const { + dest->insert(dest->begin(), grid_dims_.begin(), grid_dims_.end()); +} +void CudaAxisInfo::CopyBlockDimsTo(std::vector* dest) const { + dest->insert(dest->begin(), block_dims_.begin(), block_dims_.end()); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/lowered_func.h b/paddle/cinn/ir/lowered_func.h new file mode 100755 index 0000000000000..f237232b1c7ab --- /dev/null +++ b/paddle/cinn/ir/lowered_func.h @@ -0,0 +1,198 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_base.h" + +namespace cinn { +namespace ir { + +class _LoweredFunc_; + +/** + * A struct representing an argument to a lowered function. 
Used for specifying the function signature of generated + * code. + */ +struct Argument { + //! Input or output. + enum class IO { kInput = 0, kOutput = 1 }; + + IO io{IO::kInput}; + + Argument() = default; + explicit Argument(const ir::Buffer& buffer, IO io = IO::kInput); + explicit Argument(const ir::Var& var, IO io = IO::kInput); + + //! Set the buffer argument, all the buffer information are stored in ir::Buffer. + void set_buffer(const ir::Buffer& x); + + //! Set the var argument. + void set_var(const ir::Var& x); + + bool is_input() const { return io == IO::kInput; } + bool is_output() const { return io == IO::kOutput; } + + bool is_var() const { return var_arg_.defined(); } + bool is_buffer() const { return buffer_arg_.defined(); } + bool defined() const { return is_var() || is_buffer(); } + + ir::Buffer buffer_arg() const; + ir::Var var_arg() const; + + //! The type of the buffer or scalar. + Type type() const; + + std::string name() const; + + std::string human_readable() const; + + private: + //! The buffer field. + ir::Buffer buffer_arg_; + //! The scalar field. + ir::Var var_arg_; +}; + +//! Wrapper for _LoweredFunc_ +class LoweredFunc : public IrNodeRef { + public: + LoweredFunc() = default; + explicit LoweredFunc(IrNode* n) : IrNodeRef(n) {} + + operator Expr() const { return Expr(ptr()); } + + const _LoweredFunc_* operator->() const; + _LoweredFunc_* operator->(); +}; + +using dim3_t = std::array; +struct CudaAxisInfo { + CudaAxisInfo() { + for (int& v : grid_dims_) v = 1; + for (int& v : block_dims_) v = 1; + set_valid(false); + } + + void set_grid_dim(int offset, int x); + void set_block_dim(int offset, int x); + + int grid_dim(int offset) const; + int block_dim(int offset) const; + + void CopyGridDimsTo(std::vector* dest) const; + void CopyBlockDimsTo(std::vector* dest) const; + + inline void set_valid(bool x = false) { valid_ = x; } + inline bool valid() const { return valid_; } + + //! Extend the axis dims and keep the larger dims. + void ExtendWith(const CudaAxisInfo& other); + + private: + // the three dimensions represents x, y, z + dim3_t grid_dims_; + // the three dimensions represents x, y, z + dim3_t block_dims_; + bool valid_{false}; +}; + +std::ostream& operator<<(std::ostream& os, const CudaAxisInfo& x); + +/** + * Definition of a lowered function. Note that, it should be functional. + * + * Arguments of the function: + * + * both the input and output arguments, the output arguments are in the tail. + */ +struct _LoweredFunc_ : ExprNode<_LoweredFunc_> { + //! The name of this function. + std::string name; + + //! The Arguments used in the body of the function. + std::vector args; + + //! Temporary buffers(as output), these buffers will not appear in the function's argument list, but will be used in + //! the body. + std::vector temp_bufs; + + //! Body of this function. + Expr body; + + DeviceAPI device_api{DeviceAPI::UNK}; + + CudaAxisInfo cuda_axis_info; + + /** + * The output buffer will be resized to the size required, we leave all the expression here. + * The allocation and deallocation expressions will insert into the head and tail of the function's body. It supports + * lazy allocation/deallocation if the corresponding intristic methods support. + * + * Currently, we assume that all the input and output buffers should locate in heap, no other memory type is allowed. + */ + // @{ + std::vector alloc_output_buffer_exprs; + std::vector dealloc_output_buffer_exprs; + // @} + + //! 
something like: float* A_data = (float*)(A->memory); + std::vector buffer_data_cast_exprs; + + std::vector argument_prepare_exprs; + + static LoweredFunc Make(const std::string& name, + const std::vector& args, + const Expr& body, + const std::vector& temp_bufs); + + bool is_gpu_host() const { return cuda_axis_info.valid(); } + + void Verify() const override {} + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_; + + std::vector PrepareCreateTempBufferExprs() const; + //! Prepare the expressions for `alloc_tmp_buffer_exprs`. + std::vector PrepareAllocTempBufferExprs() const; + std::vector PrepareDeallocTempBufferExprs() const; + std::vector CudaPrepareAllocTempBufferExprs() const; + std::vector CudaAliasVarExprs() const; + void PrepareBufferCastExprs(bool with_expr_gen_tensor = true); + void PrepareCudaAxisInfoFromBody(); + + private: + void CheckValid() const; + //! Prepare the expressions for `alloc_output_buffer_exprs`. + void PrepareAllocOutputBufferExprs(); + //! Prepare the expressions for `dealloc_output_buffer_exprs`. + void PrepareDeallocOutputBufferExprs(); + //! Insert the allocation expr for temporary variables. + void AllocTempBuffer(); + + void PrepareArgumentExprs(); + //! Get all the Buffers the function body references. + //! NOTE it will return the buffers with duplicates removed(by comparing their name). + std::vector CollectAllTensorReference(bool with_expr_gen_tensor = true) const; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/module.cc b/paddle/cinn/ir/module.cc new file mode 100644 index 0000000000000..d0bd612bf0a7b --- /dev/null +++ b/paddle/cinn/ir/module.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/ir/module.h" + +#include + +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/optimize.h" + +namespace cinn { +namespace ir { + +void Module::Builder::AddFunction(ir::LoweredFunc func) { + optim::Simplify(&(func->body)); + optim::SimplifyForLoops(&(func->body)); + optim::SimplifyBlocks(&(func->body)); + func->body = optim::Optimize(func->body, module_->target); + module_->functions.push_back(func); +} + +void Module::Builder::AddBuffer(ir::Buffer buffer) { + CHECK(buffer->target.defined()) << "buffer [" << buffer->name << "]'s target is undefined"; + if (std::find_if(module_->buffers.begin(), module_->buffers.end(), [&](const Expr &x) { + return x.as_buffer()->name == buffer->name; + }) == std::end(module_->buffers)) { + module_->buffers.push_back(buffer); + if (module_->target.arch == Target::Arch::X86) { + module_->buffers.back().as_buffer()->data_alignment = 32; + } + } +} + +void Module::Builder::Clear() { + module_->buffers.clear(); + module_->functions.clear(); + module_->submodules.clear(); +} + +Module Module::Builder::Build() { + if (module_->functions.empty()) { + VLOG(1) << "Module has no functions"; + } + + auto res = ir::Module(module_.get()); + + return optim::Optimize(res, module_->target); +} + +ir::_Module_ *Module::self() { return p_->as(); } +const ir::_Module_ *Module::self() const { return p_->as(); } + +const Target &Module::target() const { return self()->target; } + +const std::string &Module::name() const { return self()->name; } + +std::vector Module::buffers() const { + std::vector buffers; + for (auto &buffer : self()->buffers) { + buffers.emplace_back(buffer.as_buffer_ref()); + } + return buffers; +} + +std::vector Module::functions() const { + std::vector functions; + for (auto &x : self()->functions) { + functions.emplace_back(x.as_lowered_func_ref()); + } + return functions; +} + +std::vector Module::submodules() const { + std::vector modules; + for (auto &x : self()->submodules) { + modules.push_back(x.as_module_ref()); + } + return modules; +} + +void Module::Compile(const backends::Outputs &outputs) const {} + +Module::operator Expr() const { return Expr(ptr()); } + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/module.h b/paddle/cinn/ir/module.h new file mode 100644 index 0000000000000..e92df6f219801 --- /dev/null +++ b/paddle/cinn/ir/module.h @@ -0,0 +1,89 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "cinn/backends/outputs.h" +#include "cinn/common/common.h" +#include "cinn/ir/lowered_func.h" +#include "cinn/lang/buffer.h" + +namespace cinn { + +namespace backends { +class CodeGenC; +} // namespace backends + +namespace ir { + +/** + * Module represents IR containing lowered function definitions and buffers. 
+ */ +class Module : public ir::IrNodeRef { + public: + struct Builder { + Builder(const std::string& name, const Target& target) : module_(common::make_shared()) { + module_->name = name; + module_->target = target; + } + + void AddFunction(ir::LoweredFunc func); + void AddBuffer(ir::Buffer buffer); + void Clear(); + + Module Build(); + + private: + Shared module_; + }; + + //! Get the target of this module. + const Target& target() const; + + //! Get the name of the module. + const std::string& name() const; + + //! The members in the module. + // @{ + std::vector buffers() const; + std::vector functions() const; + std::vector submodules() const; + // @} + + //! Compile a module to some outputs. + void Compile(const backends::Outputs& outputs) const; + + ir::_Module_* self(); + const ir::_Module_* self() const; + + ir::_Module_* operator->() { return self(); } + const ir::_Module_* operator->() const { return self(); } + + operator Expr() const; + + protected: + Module(const std::string& name, const Target& target); + + explicit Module(ir::IrNode* n) : ir::IrNodeRef(n) {} + + friend class Module::Builder; + friend class backends::CodeGenC; + friend class ::cinn::ir::Expr; + friend class ::cinn::ir::_Module_; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/operation.cc b/paddle/cinn/ir/operation.cc new file mode 100644 index 0000000000000..217d0f853b762 --- /dev/null +++ b/paddle/cinn/ir/operation.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/ir/operation.h" + +#include + +#include "cinn/common/common.h" + +namespace cinn { +namespace ir { + +Operation PlaceholderOp::Make(const std::string &name, const std::vector &shape, Type dtype) { + auto n = make_shared(); + n->name = name; + n->shape = shape; + n->set_type(dtype); + return Operation(n); +} + +const char *PlaceholderOp::func_type() const { return "placeholder_op"; } + +const char *ComputeOp::func_type() const { return "compute_op"; } + +Operation ComputeOp::Make(const std::string &name, + ComputeOp::handle_t handle, + const std::vector &shape, + const std::vector &domain, + const std::vector &reduce_axis, + const std::map &attrs, + const std::string &tag) { + auto n = make_shared(); + n->name = name; + n->producer_fn = handle; + n->shape = domain; + n->reduce_axis = reduce_axis; + n->tag = tag; + n->attrs = attrs; + auto axis = common::GenDefaultAxis(domain.size()); + std::vector _axis; + for (auto &x : axis) _axis.push_back(x); + n->body = {handle(_axis)}; + n->reduce_axis = reduce_axis; + return Operation(n); +} + +Operation CallOp::Make(const std::string &call_target, Expr call_op) { + auto n = make_shared(); + n->call_expr = call_op; + return Operation(n); +} + +Operation PrecedingViewOp::Make(const Tensor &tensor, int preceding_axis) { return Operation(); } + +const char *PrecedingViewOp::func_type() const { return PrecedingViewOp::__func_type__; } + +const char *CallOp::func_type() const { return __func_type__; } + +const char *ComputeOp::__func_type__ = "compute_op"; +const char *PlaceholderOp::__func_type__ = "placeholder_op"; +const char *CallOp::__func_type__ = "call_op"; + +const std::string &CallOp::target() const { + auto *call = call_expr.As(); + CHECK(call); + return call->name; +} +std::vector &CallOp::write_args() { + auto *call = call_expr.As(); + CHECK(call); + return call->write_args; +} +std::vector &CallOp::read_args() { + auto *call = call_expr.As(); + CHECK(call); + return call->read_args; +} +const std::vector &CallOp::write_args() const { + auto *call = call_expr.As(); + CHECK(call); + return call->write_args; +} +const std::vector &CallOp::read_args() const { + auto *call = call_expr.As(); + CHECK(call); + return call->read_args; +} +std::vector CallOp::args() const { + std::vector args; + auto &rargs = read_args(); + auto &wargs = write_args(); + args.insert(std::end(args), rargs.begin(), rargs.end()); + args.insert(std::end(args), wargs.begin(), wargs.end()); + return args; +} +const char *PrecedingViewOp::__func_type__ = "preceding_view_op"; + +const char *BufferShareOp::__func_type__ = "buffer_share_op"; +const char *BufferShareOp::func_type() const { return __func_type__; } + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/operation.h b/paddle/cinn/ir/operation.h new file mode 100644 index 0000000000000..be30969105356 --- /dev/null +++ b/paddle/cinn/ir/operation.h @@ -0,0 +1,130 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/tensor.h" + +namespace cinn { +namespace ir { + +/** + * @brief A placeholder op represents an input placeholder. + */ +struct PlaceholderOp : public _Operation_ { + //! The shape of the input. + std::vector shape; + //! The data type of the input. + Type dtype; + + static Operation Make(const std::string &name, const std::vector &shape, Type dtype); + + const char *func_type() const override; + + static char const *__func_type__; +}; + +struct CallOp : public _Operation_ { + const std::string &target() const; + + Expr call_expr; + + std::vector &read_args(); + std::vector &write_args(); + const std::vector &read_args() const; + const std::vector &write_args() const; + std::vector args() const; + + //! A reference to the target LoweredFunc if this CallOp calls an generated LoweredFunc. + Expr func; + + // the offset int the tuple of return values. + int value_slot{-1}; + + bool is_tuple_get{false}; + + //! Number of the value slots. + int num_value_slots{0}; + + CallOp() = default; + + static Operation Make(const std::string &call_target, Expr call_op); + + const char *func_type() const override; + + static char const *__func_type__; +}; + +/** + * The operation of the preceding view of a tensor. + */ +struct PrecedingViewOp : public _Operation_ { + Expr tensor; + + int preceding_axis{-1}; + + static Operation Make(const Tensor &tensor, int preceding_axis); + + const char *func_type() const override; + + static char const *__func_type__; +}; + +/** + * Share the same buffer. + */ +struct BufferShareOp : public _Operation_ { + const char *func_type() const override; + static Operation Make() { return Operation(new BufferShareOp); } + static char const *__func_type__; +}; + +/** + * @brief A Compute op that compute a tensor on certain domain. + */ +struct ComputeOp : public _Operation_ { + using handle_t = std::function &)>; + //! Var on each reduction axis, if the body is a Reduction. + std::vector reduce_axis; + //! Shape of the output. + std::vector shape; + //! The compute expression. + std::vector body; + //! The functor to generate the body, used to inline the expression if needed. + handle_t producer_fn; + + ComputeOp() = default; + + static Operation Make(const std::string &name, + ComputeOp::handle_t handle, + const std::vector &shape, + const std::vector &domain, + const std::vector &reduce_axis = {}, + const std::map &attrs = {}, + const std::string &tag = ""); + + const char *func_type() const override; + + static const char *__func_type__; +}; + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/registry.cc b/paddle/cinn/ir/registry.cc new file mode 100644 index 0000000000000..2e8a7caf1efb1 --- /dev/null +++ b/paddle/cinn/ir/registry.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/ir/registry.h" + +#include +#include // NOLINT + +namespace cinn::ir { +struct Registry::Manager { + static Manager *Global() { + static Manager manager; + return &manager; + } + + std::mutex mu; + std::map functions; + + private: + Manager() = default; + Manager(const Manager &) = delete; + void operator=(Manager &) = delete; +}; + +Registry &Registry::SetBody(lang::PackedFunc f) { + func_ = f; + return *this; +} + +Registry &Registry::SetBody(lang::PackedFunc::body_t f) { + func_ = lang::PackedFunc(f); + return *this; +} + +Registry::Registry(const std::string &name) : name_(name) {} + +/*static*/ Registry &Registry::Register(const std::string &name, bool can_override) { + auto *manager = Registry::Manager::Global(); + std::lock_guard lock(manager->mu); + if (manager->functions.count(name)) { + CHECK(can_override) << "Global PackedFunc[" << name << "] is already exists"; + } + + auto *r = new Registry(name); + manager->functions[name] = r; + return *r; +} + +/*static*/ bool Registry::Remove(const std::string &name) { + auto manager = Manager::Global(); + std::lock_guard lock(manager->mu); + auto it = manager->functions.find(name); + if (it != manager->functions.end()) { + manager->functions.erase(it); + return true; + } + return false; +} + +/*static*/ const lang::PackedFunc *Registry::Get(const std::string &name) { + auto *manager = Manager::Global(); + std::lock_guard lock(manager->mu); + auto *r = manager->functions[name]; + if (r) { + return &r->func_; + } + return nullptr; +} + +/*static*/ std::vector Registry::ListNames() { + auto *manager = Manager::Global(); + std::lock_guard lock(manager->mu); + std::vector keys; + for (const auto &_k_v_ : manager->functions) { + auto &k = std::get<0>(_k_v_); + auto &v = std::get<1>(_k_v_); + keys.push_back(k); + } + return keys; +} + +} // namespace cinn::ir diff --git a/paddle/cinn/ir/registry.h b/paddle/cinn/ir/registry.h new file mode 100644 index 0000000000000..612213a95d9cc --- /dev/null +++ b/paddle/cinn/ir/registry.h @@ -0,0 +1,46 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "cinn/lang/packed_func.h" + +namespace cinn::ir { + +class Registry { + public: + Registry &SetBody(lang::PackedFunc f); + Registry &SetBody(lang::PackedFunc::body_t f); + + static Registry &Register(const std::string &name, bool can_override = false); + static bool Remove(const std::string &name); + static const lang::PackedFunc *Get(const std::string &name); + static std::vector ListNames(); + + struct Manager; + + explicit Registry(const std::string &); + + protected: + std::string name_; + lang::PackedFunc func_; + friend class Manager; +}; + +} // namespace cinn::ir diff --git a/paddle/cinn/ir/schedule_desc.cc b/paddle/cinn/ir/schedule_desc.cc new file mode 100644 index 0000000000000..cb50cc2ab9614 --- /dev/null +++ b/paddle/cinn/ir/schedule_desc.cc @@ -0,0 +1,680 @@ +// Copyright (c) 2022 CINN Authors. 
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/ir/schedule_desc.h"
+
+#include <glog/logging.h>
+
+#include <iterator>
+#include <typeinfo>
+#include <utility>
+
+#include "cinn/common/macros.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace ir {
+
+// ------ The following code implements the `Apply` function registry for the various types of ScheduleDesc::Step
+class PackedStepContext;
+// uniformed function prototype of a scheduling operation in IRSchedule
+using StepApplyFunc = std::vector<Expr> (*)(PackedStepContext*);
+
+// formats the inputs, attrs, and uniformed apply function of a scheduling step
+class StepKindInfo {
+ public:
+  // compatible for Registry::EntryType
+  std::string name;
+
+  // format: {"<input1_name>", "<input2_name>", ...}
+  StepKindInfo& Inputs(std::vector<std::string>&& inputs) {
+    inputs_ = inputs;
+    return *this;
+  }
+  // format: {"<attr1_name>", "<attr2_name>", ...}
+  StepKindInfo& Attrs(std::vector<std::string>&& attrs) {
+    attrs_ = attrs;
+    return *this;
+  }
+  // format: APPLY_FUNC_UNIFORM(...)
+  StepKindInfo& SetApplyFn(StepApplyFunc&& func) {
+    apply_func_ = func;
+    return *this;
+  }
+
+  // execute the Apply function of this type
+  std::vector<Expr> Apply(PackedStepContext* context) const { return apply_func_(context); }
+
+ private:
+  friend class PackedStepContext;
+
+  std::vector<std::string> inputs_;
+  std::vector<std::string> attrs_;
+  StepApplyFunc apply_func_{nullptr};
+};
+
+// StepKindInfo register for all scheduling steps
+class StepKindRegistry : public Registry<StepKindInfo> {
+ public:
+  StepKindRegistry() = default;
+
+ private:
+  CINN_DISALLOW_COPY_AND_ASSIGN(StepKindRegistry);
+};
+
+// PackedStepContext is the parameter of a uniformed `Apply` function; it serves as an
+// auxiliary structure to interact with the in/out arguments of the original scheduling function in IRSchedule
+class PackedStepContext {
+ public:
+  explicit PackedStepContext(const ScheduleDesc::Step& desc, const StepKindInfo* step_kind, IRSchedule* schedule)
+      : ir_schedule_(schedule) {
+    Build(desc, step_kind);
+  }
+
+  // get the pointer of the current IRSchedule object
+  IRSchedule* ScheduleHandler() const { return ir_schedule_; }
+
+  // get the idx-th input whose signature is Expr
+  Expr InputAt(size_t idx) const {
+    CHECK_LT(idx, input_range_.size()) << "idx overranges";
+    const auto& range = input_range_.at(idx);
+    CHECK(range.second - range.first == 1) << "not single param";
+    return inputs_[range.first];
+  }
+
+  // get the idx-th input whose signature is `std::vector<Expr>`
+  std::vector<Expr> InputsAt(size_t idx) const {
+    CHECK_LT(idx, input_range_.size()) << "idx overranges";
+    const auto& range = input_range_.at(idx);
+    std::vector<Expr> results;
+    for (size_t s = range.first; s < range.second; ++s) {
+      results.emplace_back(inputs_[s]);
+    }
+    return results;
+  }
+
+  // get the idx-th attribute value with the correct type
+  template <typename AttrType>
+  const AttrType& AttrAt(size_t idx) const {
+    try {
+      return absl::get<AttrType>(attrs_.at(idx));
+    } catch (absl::bad_variant_access& ex) {
+      LOG(FATAL) << "Attribute cast error, idx:" << idx << ", get type:" << typeid(AttrType).name()
+                 << ", real index:" << attrs_.at(idx).index();
+      throw ex;
+    }
+  }
+
+ private:
+  void Build(const ScheduleDesc::Step& desc, const StepKindInfo* step_kind) {
+    // build inputs
+    size_t input_idx = 0;
+    for (auto&& param_name : step_kind->inputs_) {
+      auto arg_it = desc.inputs.find(param_name);
+      CHECK(arg_it != desc.inputs.end()) << "Can't find param:" << param_name;
+      auto&& args = arg_it->second;
+      inputs_.insert(inputs_.end(), std::make_move_iterator(args.begin()), std::make_move_iterator(args.end()));
+      input_range_.emplace_back(input_idx, input_idx + args.size());
+      input_idx += args.size();
+    }
+
+    // build attrs
+    size_t attr_idx = 0;
+    for (auto&& attr_name : step_kind->attrs_) {
+      auto attr_it = desc.attrs.find(attr_name);
+      CHECK(attr_it != desc.attrs.end()) << "Can't find attribute:" << attr_name;
+      attrs_.emplace_back(attr_it->second);
+      ++attr_idx;
+    }
+  }
+
+  IRSchedule* ir_schedule_;
+  std::vector<Expr> inputs_;
+  std::vector<std::pair<size_t, size_t>> input_range_;
+  std::vector<utils::Attribute> attrs_;
+};
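+
+// For intuition: a member function such as
+//     void IRSchedule::Vectorize(const Expr& loop, int factor);
+// is adapted by FreeFuncConverter (below) into a free function whose first
+// parameter is the IRSchedule*, and ApplyFuncImpl then peels the remaining
+// parameters one by one: Expr-typed parameters are fetched with
+// InputAt/InputsAt and attribute-typed parameters with AttrAt<T>, so every
+// step kind is exposed through the same uniform signature
+//     std::vector<Expr> (*)(PackedStepContext*).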
+
+#define CINN_SPECIALIZE_ApplyCallHelper(attr_type)                                     \
+  template <typename... Tail>                                                          \
+  struct ApplyCallHelper<attr_type, Tail...> {                                         \
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>         \
+    static std::vector<Expr> Apply(PackedStepContext* ctx, PreviousArgs... pargs) {    \
+      using rf_attr_type = std::remove_reference<attr_type>::type;                     \
+      using rc_attr_type = std::remove_const<rf_attr_type>::type;                      \
+      const auto& arg = ctx->AttrAt<rc_attr_type>(attr_idx);                           \
+      return ApplyCallHelper<Tail...>::template Apply<in_idx, attr_idx + 1, out_idx>(  \
+          ctx, std::forward<PreviousArgs>(pargs)..., arg);                             \
+    }                                                                                  \
+  }
+
+template <typename T>
+struct TypeTag {};
+
+// used for converting a member function of the IRSchedule to a free function
+// whose first parameter is a pointer to the IRSchedule.
+template <typename F, F f>
+struct FreeFuncConverter;
+
+template <typename Return, typename... Args, Return (IRSchedule::*impl_fn)(Args...)>
+struct FreeFuncConverter<Return (IRSchedule::*)(Args...), impl_fn> {
+  static Return Apply(IRSchedule* sch, Args... args) { return (sch->*impl_fn)(std::forward<Args>(args)...); }
+};
+
+template <typename Return, typename... Args, Return (IRSchedule::*impl_fn)(Args...) const>
+struct FreeFuncConverter<Return (IRSchedule::*)(Args...) const, impl_fn> {
+  static Return Apply(IRSchedule* sch, Args... args) { return (sch->*impl_fn)(std::forward<Args>(args)...); }
+};
+
+// used for formatting scheduling functions with various function signatures into a uniformed form
+template <typename F, F f>
+struct ApplyFuncImpl;
+
+template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
+struct ApplyFuncImpl<Return (*)(Args...), impl_fn> {
+  static std::vector<Expr> Apply(PackedStepContext* ctx) {
+    return ApplyCallHelper<Args..., TypeTag<int>>::template Apply<0, 0, 0>(ctx);
+  }
+
+ private:
+  template <typename... RemainingArgs>
+  struct ApplyCallHelper;
+
+  // the signature of input parameters of a scheduling operation can only
+  // be one of IRSchedule, Expr or std::vector<Expr>
+  template <typename... Tail>
+  struct ApplyCallHelper<IRSchedule*, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx>
+    static std::vector<Expr> Apply(PackedStepContext* ctx) {
+      static_assert(in_idx == 0, "IRSchedule* must be the first argument");
+      IRSchedule* ir_schedule = ctx->ScheduleHandler();
+      return ApplyCallHelper<Tail...>::template Apply<in_idx + 1, attr_idx, out_idx>(ctx, ir_schedule);
+    }
+  };
+
+  template <typename... Tail>
+  struct ApplyCallHelper<Expr, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static std::vector<Expr> Apply(PackedStepContext* ctx, PreviousArgs... pargs) {
+      auto arg = ctx->InputAt(in_idx - 1);
+      return ApplyCallHelper<Tail...>::template Apply<in_idx + 1, attr_idx, out_idx>(
+          ctx, std::forward<PreviousArgs>(pargs)..., arg);
+    }
+  };
+
+  template <typename... Tail>
+  struct ApplyCallHelper<const Expr&, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static std::vector<Expr> Apply(PackedStepContext* ctx, PreviousArgs... pargs) {
+      auto arg = ctx->InputAt(in_idx - 1);
+      return ApplyCallHelper<Tail...>::template Apply<in_idx + 1, attr_idx, out_idx>(
+          ctx, std::forward<PreviousArgs>(pargs)..., arg);
+    }
+  };
+
+  template <typename... Tail>
+  struct ApplyCallHelper<const std::vector<Expr>&, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static std::vector<Expr> Apply(PackedStepContext* ctx, PreviousArgs... pargs) {
+      auto arg = ctx->InputsAt(in_idx - 1);
+      return ApplyCallHelper<Tail...>::template Apply<in_idx + 1, attr_idx, out_idx>(
+          ctx, std::forward<PreviousArgs>(pargs)..., arg);
+    }
+  };
+
+  CINN_SPECIALIZE_ApplyCallHelper(bool);
+  CINN_SPECIALIZE_ApplyCallHelper(int);
+  CINN_SPECIALIZE_ApplyCallHelper(float);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::string&);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<bool>&);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<int>&);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<float>&);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<std::string>&);
+  CINN_SPECIALIZE_ApplyCallHelper(int64_t);
+  CINN_SPECIALIZE_ApplyCallHelper(double);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<int64_t>&);
+  CINN_SPECIALIZE_ApplyCallHelper(const std::vector<double>&);
+
+  template <int out_idx, typename T>
+  struct ApplyReturnHelper;
+
+  template <int out_idx>
+  struct ApplyReturnHelper<out_idx, void> {
+    static std::vector<Expr> Apply(Args... args) {
+      impl_fn(std::forward<Args>(args)...);
+      return {};
+    }
+  };
+
+  template <int out_idx>
+  struct ApplyReturnHelper<out_idx, Expr> {
+    static std::vector<Expr> Apply(Args... args) {
+      auto ret = impl_fn(std::forward<Args>(args)...);
+      return {ret};
+    }
+  };
+
+  template <int out_idx>
+  struct ApplyReturnHelper<out_idx, std::vector<Expr>> {
+    static std::vector<Expr> Apply(Args... args) { return impl_fn(std::forward<Args>(args)...); }
+  };
+
+  // end: base template
+  template <typename T>
+  struct ApplyCallHelper<TypeTag<T>> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static std::vector<Expr> Apply(PackedStepContext* ctx, PreviousArgs... pargs) {
+      static_assert(out_idx == 0, "Output is exported from return value");
+      return ApplyReturnHelper<out_idx, Return>::Apply(std::forward<PreviousArgs>(pargs)...);
+    }
+  };
+};
+
+#define APPLY_FUNC_UNIFORM(...) ::cinn::ir::ApplyFuncImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Apply
+#define FREE_FUNCTION_CONVERTER(...) ::cinn::ir::FreeFuncConverter<decltype(__VA_ARGS__), __VA_ARGS__>::Apply
+
+#define CINN_BUILD_STEP_KIND(TypeName)                                \
+  static ::cinn::ir::StepKindInfo& __step_kind_registrar_##TypeName = \
+      ::cinn::ir::StepKindRegistry::Global()->__REGISTER_OR_GET__(#TypeName)
+
+// register StepKindInfo for every type of scheduling operation
+// clang-format off
+CINN_BUILD_STEP_KIND(GetAllBlocks)
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<std::vector<Expr> (IRSchedule::*)() const>(&IRSchedule::GetAllBlocks))));
+
+CINN_BUILD_STEP_KIND(GetChildBlocks)
+    .Inputs({"expr"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<std::vector<Expr> (IRSchedule::*)(const Expr&) const>(&IRSchedule::GetChildBlocks))));
+
+CINN_BUILD_STEP_KIND(GetLoops)
+    .Inputs({"block"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<std::vector<Expr> (IRSchedule::*)(const Expr&) const>(&IRSchedule::GetLoops))));
+
+CINN_BUILD_STEP_KIND(GetLoopsWithName)
+    .Attrs({"block_name"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<std::vector<Expr> (IRSchedule::*)(const std::string&) const>(&IRSchedule::GetLoops))));
+
+CINN_BUILD_STEP_KIND(GetBlock)
+    .Attrs({"block_name"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const std::string&) const>(&IRSchedule::GetBlock))));
+
+CINN_BUILD_STEP_KIND(Split)
+    .Inputs({"loop", "factors"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<std::vector<Expr> (IRSchedule::*)(const Expr&, const std::vector<Expr>&)>(&IRSchedule::Split))));
+
+CINN_BUILD_STEP_KIND(Fuse)
+    .Inputs({"loops"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const std::vector<Expr>&)>(&IRSchedule::Fuse))));
+
+CINN_BUILD_STEP_KIND(FuseWithName)
+    .Attrs({"block_name", "loops_index"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const std::string&, const std::vector<int>&)>(&IRSchedule::Fuse))));
+
+CINN_BUILD_STEP_KIND(FuseWithBlock)
+    .Inputs({"block"})
+    .Attrs({"loops_index"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const Expr&, const std::vector<int>&)>(&IRSchedule::Fuse))));
+
+CINN_BUILD_STEP_KIND(ComputeAt)
+    .Inputs({"block", "loop"})
+    .Attrs({"keep_unit_loops"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ComputeAt)));
+
+CINN_BUILD_STEP_KIND(SimpleComputeAt)
+    .Inputs({"block", "loop"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SimpleComputeAt)));
+
+CINN_BUILD_STEP_KIND(ReverseComputeAt)
+    .Inputs({"block", "loop"})
+    .Attrs({"keep_unit_loops"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ReverseComputeAt)));
+
+CINN_BUILD_STEP_KIND(GetRootBlock)
+    .Inputs({"expr"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::GetRootBlock)));
+
+CINN_BUILD_STEP_KIND(CacheRead)
+    .Inputs({"block"})
+    .Attrs({"read_buffer_index", "memory_type"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::CacheRead)));
+
+CINN_BUILD_STEP_KIND(CacheWrite)
+    .Inputs({"block"})
+    .Attrs({"write_buffer_index", "memory_type"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::CacheWrite)));
+
+CINN_BUILD_STEP_KIND(SyncThreads)
+    .Inputs({"ir_node"})
+    .Attrs({"after_node"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SyncThreads)));
+
+CINN_BUILD_STEP_KIND(SetBuffer)
+    .Inputs({"block"})
+    .Attrs({"memory_type", "fixed"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SetBuffer)));
+
+CINN_BUILD_STEP_KIND(Reorder)
+    .Inputs({"loops"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const std::vector<Expr>&)>(&IRSchedule::Reorder))));
+
+CINN_BUILD_STEP_KIND(ReorderWithBlock)
+    .Inputs({"block"})
+    .Attrs({"loops_index"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const Expr&, const std::vector<int>&)>(&IRSchedule::Reorder))));
+
+CINN_BUILD_STEP_KIND(ReorderWithName)
+    .Attrs({"block_name", "loops_index"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(
+        static_cast<Expr (IRSchedule::*)(const std::string&, const std::vector<int>&)>(&IRSchedule::Reorder))));
+
+CINN_BUILD_STEP_KIND(Parallel)
+    .Inputs({"loop"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Parallel)));
+
+CINN_BUILD_STEP_KIND(Vectorize)
+    .Inputs({"loop"})
+    .Attrs({"factor"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Vectorize)));
+
+CINN_BUILD_STEP_KIND(Unroll)
+    .Inputs({"loop"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unroll)));
+
+CINN_BUILD_STEP_KIND(ComputeInline)
+    .Inputs({"schedule_block"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ComputeInline)));
+
+CINN_BUILD_STEP_KIND(ReverseComputeInline)
+    .Inputs({"schedule_block"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ReverseComputeInline)));
+
+CINN_BUILD_STEP_KIND(Bind)
+    .Inputs({"loop"})
+    .Attrs({"thread_axis"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Bind)));
+
+CINN_BUILD_STEP_KIND(Rfactor)
+    .Inputs({"rf_loop"})
+    .Attrs({"rf_axis"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Rfactor)));
+
+CINN_BUILD_STEP_KIND(MergeExprs)
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::MergeExprs)));
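+
+// The two macros compose as follows: FREE_FUNCTION_CONVERTER turns a
+// pointer-to-member of IRSchedule into a plain free function whose first
+// parameter is the IRSchedule*, and APPLY_FUNC_UNIFORM wraps that free
+// function into the uniform StepApplyFunc signature. A registration for a
+// hypothetical scheduling primitive `Foo(const Expr& loop, int factor)`
+// would follow the same pattern (illustrative sketch only; IRSchedule has
+// no such member):
+//
+//   CINN_BUILD_STEP_KIND(Foo)
+//       .Inputs({"loop"})
+//       .Attrs({"factor"})
+//       .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Foo)));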
+
+template <typename AttrType>
+void Annotate(IRSchedule* ir_sch, const Expr&, const std::string&, AttrType);
+template <>
+void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, int value) {
+  ir_sch->Annotate(block, key, value);
+}
+template <>
+void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, bool value) {
+  ir_sch->Annotate(block, key, value);
+}
+template <>
+void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, float value) {
+  ir_sch->Annotate(block, key, value);
+}
+void AnnotateStringAttr(IRSchedule* ir_sch, const Expr& block, const std::string& key, const std::string& value) {
+  ir_sch->Annotate(block, key, value);
+}
+
+CINN_BUILD_STEP_KIND(AnnotateIntAttr)
+    .Inputs({"block"})
+    .Attrs({"key", "value"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate<int>));
+
+CINN_BUILD_STEP_KIND(AnnotateBoolAttr)
+    .Inputs({"block"})
+    .Attrs({"key", "value"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate<bool>));
+
+CINN_BUILD_STEP_KIND(AnnotateFloatAttr)
+    .Inputs({"block"})
+    .Attrs({"key", "value"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate<float>));
+
+CINN_BUILD_STEP_KIND(AnnotateStringAttr)
+    .Inputs({"block"})
+    .Attrs({"key", "value"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(AnnotateStringAttr));
+
+CINN_BUILD_STEP_KIND(Unannotate)
+    .Inputs({"block"})
+    .Attrs({"key"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unannotate)));
+
+CINN_BUILD_STEP_KIND(FlattenLoops)
+    .Inputs({"loops"})
+    .Attrs({"force_flat"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::FlattenLoops)));
+
+CINN_BUILD_STEP_KIND(SamplePerfectTile)
+    .Inputs({"loop"})
+    .Attrs({"n", "max_innermost_factor", "decision"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SamplePerfectTile)));
+
+CINN_BUILD_STEP_KIND(TagPostSchedule)
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::TagPostSchedule)));
+
+CINN_BUILD_STEP_KIND(SampleCategorical)
+    .Attrs({"candidates", "probs", "decision"})
+    .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SampleCategorical)));
+// clang-format on
+
+// ------ The following code implements the member functions of the ScheduleDesc class
+void AttrVariantToProto(const utils::Attribute& attr, proto::ScheduleDesc_Attr* attr_proto) {
+#define SET_DESC_SINGLE_ITEM(index, built_type, proto_type, proto_field)   \
+  case index:                                                              \
+    attr_proto->set_dtype(proto::ScheduleDesc_Attr_DataType_##proto_type); \
+    attr_proto->set_##proto_field(absl::get<built_type>(attr));            \
+    break;
+
+#define SET_DESC_REPEATED_ITEM(index, built_type, proto_type, proto_field)  \
+  case index: {                                                             \
+    attr_proto->set_dtype(proto::ScheduleDesc_Attr_DataType_##proto_type);  \
+    const auto& values = absl::get<built_type>(attr);                       \
+    attr_proto->mutable_##proto_field()->Reserve(values.size());            \
+    *attr_proto->mutable_##proto_field() = {values.begin(), values.end()};  \
+    break;                                                                  \
+  }
+
+  switch (attr.index()) {
+    SET_DESC_SINGLE_ITEM(0, bool, BOOLEAN, b);
+    SET_DESC_SINGLE_ITEM(1, float, FLOAT, f);
+    SET_DESC_SINGLE_ITEM(2, int, INT, i);
+    SET_DESC_SINGLE_ITEM(3, std::string, STRING, s);
+    SET_DESC_REPEATED_ITEM(4, std::vector<bool>, BOOLEANS, bools);
+    SET_DESC_REPEATED_ITEM(5, std::vector<int>, INTS, ints);
+    SET_DESC_REPEATED_ITEM(6, std::vector<float>, FLOATS, floats);
+    SET_DESC_REPEATED_ITEM(7, std::vector<std::string>, STRINGS, strings);
+    SET_DESC_SINGLE_ITEM(8, int64_t, LONG, l);
+    SET_DESC_SINGLE_ITEM(9, double, DOUBLE, d);
+    SET_DESC_REPEATED_ITEM(10, std::vector<int64_t>, LONGS, longs);
+    SET_DESC_REPEATED_ITEM(11, std::vector<double>, DOUBLES, doubles);
+    default:
+      LOG(FATAL) << "Invalid index:" << attr.index();
+  }
+
+#undef SET_DESC_SINGLE_ITEM
+#undef SET_DESC_REPEATED_ITEM
+}
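+
+// For example, a utils::Attribute currently holding a std::vector<int> has
+// attr.index() == 5, so the switch above stores it in the repeated `ints`
+// field with dtype INTS; AttrProtoToVariant below performs the inverse
+// mapping when a serialized ScheduleDesc is loaded back.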
+
+utils::Attribute AttrProtoToVariant(const proto::ScheduleDesc_Attr& attr) {
+  utils::Attribute value;
+#define PARSE_DESC_SINGLE_ITEM(proto_type, proto_field, built_type)  \
+  case proto::ScheduleDesc_Attr_DataType_##proto_type:               \
+    value = built_type(attr.proto_field());                          \
+    break;
+
+#define PARSE_DESC_REPEATED_ITEM(proto_type, proto_field, built_type)            \
+  case proto::ScheduleDesc_Attr_DataType_##proto_type:                           \
+    value = built_type({attr.proto_field().begin(), attr.proto_field().end()});  \
+    break;
+
+  switch (attr.dtype()) {
+    PARSE_DESC_SINGLE_ITEM(BOOLEAN, b, bool);
+    PARSE_DESC_SINGLE_ITEM(INT, i, int);
+    PARSE_DESC_SINGLE_ITEM(FLOAT, f, float);
+    PARSE_DESC_SINGLE_ITEM(STRING, s, std::string);
+    PARSE_DESC_REPEATED_ITEM(BOOLEANS, bools, std::vector<bool>);
+    PARSE_DESC_REPEATED_ITEM(INTS, ints, std::vector<int>);
+    PARSE_DESC_REPEATED_ITEM(FLOATS, floats, std::vector<float>);
+    PARSE_DESC_REPEATED_ITEM(STRINGS, strings, std::vector<std::string>);
+    PARSE_DESC_SINGLE_ITEM(LONG, l, int64_t);
+    PARSE_DESC_SINGLE_ITEM(DOUBLE, d, double);
+    PARSE_DESC_REPEATED_ITEM(LONGS, longs, std::vector<int64_t>);
+    PARSE_DESC_REPEATED_ITEM(DOUBLES, doubles, std::vector<double>);
+    default:
+      LOG(FATAL) << "Invalid type:" << attr.DebugString();
+  }
+
+#undef PARSE_DESC_SINGLE_ITEM
+#undef PARSE_DESC_REPEATED_ITEM
+  return value;
+}
+
+// Expr hash functor, presents how to hash an Expr
+struct ExprHash {
+  size_t operator()(const Expr& e) const { return std::hash<void*>()(e.ptr()); }
+};
+// Expr equal functor, presents whether an Expr pair is equal
+struct ExprEqual {
+  bool operator()(const Expr& lhs, const Expr& rhs) const { return lhs.get() == rhs.get(); }
+};
+
+void ScheduleDesc::Append(Step&& step) { steps_.emplace_back(std::move(step)); }
+
+void ScheduleDesc::Pop() {
+  if (!steps_.empty()) {
+    steps_.pop_back();
+  }
+}
+
+void ScheduleDesc::Replay(IRSchedule* schedule, bool without_post_schedule) const {
+  ReplayWithProto(this->ToProto(), schedule, without_post_schedule);
+}
+
+proto::ScheduleDesc ScheduleDesc::ToProto() const {
+  // map each Expr to a formatted name (e1, e2, ...)
+  absl::flat_hash_map<Expr, std::string, ExprHash, ExprEqual> expr2name;
+  proto::ScheduleDesc desc_proto;
+
+  for (auto&& step : steps_) {
+    auto* step_proto = desc_proto.add_steps();
+    step_proto->set_type(step.type);
+    // inputs of a step must refer to Exprs produced by preceding steps
+    for (auto&& param2exprs : step.inputs) {
+      const std::string& param_name = param2exprs.first;
+      auto* expr_desc = step_proto->add_inputs();
+      expr_desc->set_parameter(param_name);
+      for (auto&& expr : param2exprs.second) {
+        auto expr_it = expr2name.find(expr);
+        CHECK(expr_it != expr2name.end()) << "Can't find expr of param_name: " << param_name;
+        expr_desc->add_arguments(expr_it->second);
+      }
+    }
+
+    // each output Expr is represented by a formatted name, to be referred to by succeeding steps
+    for (auto&& expr : step.outputs) {
+      std::string local_name = "e" + std::to_string(expr2name.size());
+      expr2name.emplace(expr, local_name);
+      step_proto->add_outputs(expr2name.at(expr));
+    }
+
+    for (auto&& attr2value : step.attrs) {
+      auto* attr_proto = step_proto->add_attrs();
+      const auto& attr_value = attr2value.second;
+      VLOG(5) << "Attr.index:" << attr_value.index();
+      attr_proto->set_name(attr2value.first);
+      AttrVariantToProto(attr_value, attr_proto);
+    }
+  }
+  return desc_proto;
+}
+
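+// Naming scheme example: if a GetLoops step produces three loop Exprs, they
+// are recorded as outputs "e0", "e1" and "e2"; a later Fuse step consuming
+// them serializes its input as {parameter: "loops", arguments: ["e0", "e1",
+// "e2"]}, which ReplayWithProto resolves back through name2expr below.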
+
+std::vector<Expr> ScheduleDesc::ReplayWithProto(const proto::ScheduleDesc& desc_proto,
+                                                IRSchedule* sch,
+                                                bool without_post_schedule) {
+  VLOG(4) << "proto::ScheduleDesc:\n" << desc_proto.DebugString();
+  if (desc_proto.steps().empty()) {
+    LOG(WARNING) << "Input proto::ScheduleDesc is empty";
+    return {};
+  }
+
+  // map a formatted name (e1, e2, ...) to an Expr
+  absl::flat_hash_map<std::string, Expr> name2expr;
+  std::vector<Expr> last_outputs;
+
+  // restore each scheduling step and apply it to the new IRSchedule object
+  for (auto&& step_proto : desc_proto.steps()) {
+    VLOG(4) << "Replay step:\n" << step_proto.DebugString();
+    ScheduleDesc::Step step;
+    step.type = step_proto.type();
+    CHECK(!step.type.empty()) << "Name of StepKind is empty";
+    if (without_post_schedule && step.type == "TagPostSchedule") {
+      break;
+    }
+    const StepKindInfo* step_kind = StepKindRegistry::Global()->Find(step.type);
+    CHECK(step_kind) << "Can't find StepKind:" << step.type;
+
+    for (auto&& param2args : step_proto.inputs()) {
+      for (auto&& arg : param2args.arguments()) {
+        auto arg_it = name2expr.find(arg);
+        CHECK(arg_it != name2expr.end()) << "Can't find argument:" << arg;
+        step.inputs[param2args.parameter()].emplace_back(arg_it->second);
+      }
+    }
+    for (auto&& attr : step_proto.attrs()) {
+      step.attrs[attr.name()] = AttrProtoToVariant(attr);
+    }
+
+    PackedStepContext context(step, step_kind, sch);
+    step.outputs = step_kind->Apply(&context);
+    CHECK_EQ(step_proto.outputs().size(), step.outputs.size()) << "Output size not matched";
+    for (size_t i = 0; i < step.outputs.size(); ++i) {
+      name2expr[step_proto.outputs(i)] = step.outputs.at(i);
+    }
+    last_outputs = std::move(step.outputs);
+  }
+  return last_outputs;
+}
+
+ScheduleDesc ScheduleDesc::ForkAndUpdate(int step_idx, utils::Attribute decision, bool without_post_schedule) const {
+  int n_valid_step = 0;
+  if (!without_post_schedule) {
+    n_valid_step = steps_.size();
+  } else {
+    for (const auto& step : steps_) {
+      if (step.type != "TagPostSchedule") {
+        ++n_valid_step;
+      } else {
+        break;
+      }
+    }
+  }
+  std::vector<Step> new_steps(steps_.begin(), steps_.begin() + n_valid_step);
+  new_steps[step_idx].attrs["decision"] = decision;
+  return ScheduleDesc(std::move(new_steps));
+}
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/schedule_desc.h b/paddle/cinn/ir/schedule_desc.h
new file mode 100644
index 0000000000000..43a1820cfe9e0
--- /dev/null
+++ b/paddle/cinn/ir/schedule_desc.h
@@ -0,0 +1,106 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <absl/container/flat_hash_map.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/schedule_desc.pb.h"
+#include "cinn/utils/registry.h"
+#include "cinn/utils/type_defs.h"
+
+namespace cinn {
+namespace ir {
+
+// A ScheduleDesc describes the scheduling process of an ir::ModuleExpr; it records
+// all transform/getting operations executed by a corresponding ir::IRSchedule.
+// A ScheduleDesc can be serialized to JSON format and saved to a file. For deserialization,
+// it can be re-applied to a new IRSchedule that is initialized with a semantics-equal
+// original ir::ModuleExpr, and then achieves the same result.
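+//
+// A minimal usage sketch (illustrative; `block` comes from a previous step and
+// `new_sch` is an IRSchedule built from a semantics-equal ModuleExpr):
+//
+//   ScheduleDesc desc;
+//   desc.Append(ScheduleDesc::Step(
+//       "GetBlock", {}, {{"block_name", std::string("B")}}, {block}));
+//   proto::ScheduleDesc p = desc.ToProto();         // serialize
+//   ScheduleDesc::ReplayWithProto(p, &new_sch);     // re-apply on a fresh schedule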
+
+class IRSchedule;  // forward declaration to avoid cross-reference
+class ScheduleDesc {
+ public:
+  // each operation executed through IRSchedule is recorded as a step
+  struct Step {
+    std::string type;  // step name
+    absl::flat_hash_map<std::string, std::vector<Expr>> inputs;
+    utils::AttributeMap attrs;
+    std::vector<Expr> outputs;
+    Step() = default;
+    Step(std::string type_i,
+         absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
+         utils::AttributeMap attrs_i,
+         std::vector<Expr> outputs_i)
+        : type(type_i), inputs(inputs_i), attrs(attrs_i), outputs(outputs_i) {}
+  };
+
+  /**
+   * \brief Re-apply a scheduling process represented as a proto::ScheduleDesc to a new IRSchedule object.
+   * @param desc_proto The proto of the ScheduleDesc to be re-applied.
+   * @param sch The IRSchedule on which the description is replayed.
+   * @param without_post_schedule Determine whether to delete the post schedules.
+   */
+  static std::vector<Expr> ReplayWithProto(const proto::ScheduleDesc& desc_proto,
+                                           IRSchedule* sch,
+                                           bool without_post_schedule = false);
+
+  ScheduleDesc() = default;
+
+  ScheduleDesc(const std::vector<Step>& steps) : steps_(steps) {}
+
+  ScheduleDesc(std::vector<Step>&& steps) : steps_(steps) {}
+
+  // Append a new step
+  void Append(Step&& step);
+
+  // Pop the last step
+  void Pop();
+
+  /**
+   * \brief Replay this description on a new IRSchedule that is initialized with a semantics-equal original ModuleExpr.
+   * @param schedule The IRSchedule on which the description is replayed.
+   * @param without_post_schedule Determine whether to delete the post schedules.
+   */
+  void Replay(IRSchedule* schedule, bool without_post_schedule = false) const;
+
+  // convert to a proto::ScheduleDesc object
+  proto::ScheduleDesc ToProto() const;
+
+  // return the detail string of a ScheduleDesc for debugging
+  std::string DebugString() const { return ToProto().DebugString(); }
+
+  std::vector<Step> Steps() const { return steps_; }
+
+  bool Empty() const { return steps_.empty(); }
+
+  /**
+   * \brief Fork this ScheduleDesc and update a step of the new ScheduleDesc with a new decision.
+   * @param step_idx The index of the step to be updated.
+   * @param decision The new decision.
+   * @param without_post_schedule Determine whether to delete the post schedules.
+   * @return The new ScheduleDesc.
+   */
+  ScheduleDesc ForkAndUpdate(int step_idx, utils::Attribute decision, bool without_post_schedule) const;
+
+ private:
+  std::vector<Step> steps_;  // all operations are recorded in order.
+};
+
+}  // namespace ir
+}  // namespace cinn
diff --git a/paddle/cinn/ir/schedule_desc.proto b/paddle/cinn/ir/schedule_desc.proto
new file mode 100644
index 0000000000000..829478cf22dd4
--- /dev/null
+++ b/paddle/cinn/ir/schedule_desc.proto
@@ -0,0 +1,67 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
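+
+// Example of a single serialized step in text format (illustrative; the
+// argument names follow the "e<N>" convention used by ScheduleDesc::ToProto):
+//
+//   steps {
+//     type: "Fuse"
+//     inputs { parameter: "loops" arguments: "e0" arguments: "e1" }
+//     outputs: "e2"
+//   }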
+
+syntax = "proto3";
+
+package cinn.ir.proto;
+
+message ScheduleDesc {
+  message Expr {
+    string parameter = 1;
+    repeated string arguments = 2;
+  };
+
+  // Attribute type and value
+  message Attr {
+    enum DataType {
+      BOOLEAN = 0;
+      INT = 1;
+      FLOAT = 2;
+      STRING = 3;
+      BOOLEANS = 4;
+      INTS = 5;
+      FLOATS = 6;
+      STRINGS = 7;
+      LONG = 8;
+      DOUBLE = 9;
+      LONGS = 10;
+      DOUBLES = 11;
+    };
+
+    string name = 1;
+    DataType dtype = 2;
+    bool b = 3;
+    int32 i = 4;
+    float f = 5;
+    string s = 6;
+    repeated bool bools = 7;
+    repeated int32 ints = 8;
+    repeated float floats = 9;
+    repeated string strings = 10;
+    int64 l = 11;
+    double d = 12;
+    repeated int64 longs = 13;
+    repeated double doubles = 14;
+  };
+
+  message Step {
+    string type = 1;
+    repeated Expr inputs = 2;
+    repeated string outputs = 3;
+    repeated Attr attrs = 4;
+  };
+
+  // scheduling operation sequence
+  repeated Step steps = 1;
+};
diff --git a/paddle/cinn/ir/schedule_desc_test.cc b/paddle/cinn/ir/schedule_desc_test.cc
new file mode 100644
index 0000000000000..171f1fbedc3f8
--- /dev/null
+++ b/paddle/cinn/ir/schedule_desc_test.cc
@@ -0,0 +1,809 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/ir/schedule_desc.h"
+
+#include <gtest/gtest.h>
+#include
+
+#include "cinn/cinn.h"
+#include "cinn/common/context.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/lang/lower.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/utils/string.h"
+#include "cinn/utils/type_defs.h"
+
+namespace cinn {
+namespace ir {
+
+// Return the lowered IR AST for the example functions used in this test
+std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
+                                          const Target& target,
+                                          bool need_c = false,
+                                          const std::string& operation = "elementwise-copy") {
+  CHECK(shape.size() == 2 || shape.size() == 3) << "shape size should be 2 or 3";
+  std::vector<Expr> domain;
+  for (auto i = 0; i < shape.size(); ++i) {
+    domain.emplace_back(shape[i]);
+  }
+
+  Placeholder<float> A("A", domain);
+  ir::Tensor B, C;
+
+  if (operation == "elementwise-copy") {
+    if (domain.size() == 2) {
+      B = Compute(
+          domain, [&A](Var i, Var j) { return A(i, j); }, "B");
+      C = Compute(
+          domain, [&B](Var i, Var j) { return B(i, j); }, "C");
+    } else {
+      B = Compute(
+          domain, [&A](Var i, Var j, Var k) { return A(i, j, k); }, "B");
+      C = Compute(
+          domain, [&B](Var i, Var j, Var k) { return B(i, j, k); }, "C");
+    }
+  }
+
+  if (operation == "elementwise-add_const") {
+    if (domain.size() == 2) {
+      B = Compute(
+          domain, [&A](Var i, Var j) { return A(i, j) * Expr(2.f); }, "B");
+      C = Compute(
+          domain, [&B](Var i, Var j) { return B(i, j) + Expr(1.f); }, "C");
+    } else {
+      B = Compute(
+          domain, [&A](Var i, Var j, Var k) { return A(i, j, k) * Expr(2.f); }, "B");
+      C = Compute(
+          domain, [&B](Var i, Var j, Var k) { return B(i, j, k) + Expr(1.f); }, "C");
+    }
+  }
+
+  if (need_c) {
+    return cinn::lang::LowerVec("test_func", CreateStages({A, B, C}), {A, C}, {}, {}, nullptr, target, true);
+  }
+
+  return cinn::lang::LowerVec("test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
+}
+
+// Create a new IRSchedule with a copied ir::LoweredFunc AST
+IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs) {
+  std::vector<Expr> exprs;
+  for (auto&& func : lowered_funcs) {
+    exprs.emplace_back(optim::IRCopy(func->body));
+  }
+  return ir::IRSchedule(ir::ModuleExpr(exprs));
+}
+
+// Generate source code with the transformed ModuleExpr
+std::string SourceCodeGen(const ModuleExpr& module_expr,
+                          const std::vector<ir::LoweredFunc>& lowered_funcs,
+                          const Target& target) {
+  auto exprs = module_expr.GetExprs();
+  CHECK_EQ(exprs.size(), lowered_funcs.size()) << "size of funcs is not equal";
+  std::vector updated_funcs = optim::IRCopy(lowered_funcs);
+  Module::Builder builder("test_module", target);
+  for (auto i = 0; i < lowered_funcs.size(); ++i) {
+    updated_funcs[i]->body = optim::IRCopy(exprs.at(i));
+    builder.AddFunction(updated_funcs[i]);
+  }
+  auto module = builder.Build();
+  CodeGenC codegen(target);
+  codegen.SetInlineBuiltinCodes(false);
+  return codegen.Compile(module, CodeGenC::OutputKind::CImpl);
+}
+
+class TestScheduleDesc : public ::testing::Test {
+ public:
+  Target target = common::DefaultHostTarget();
+  std::vector<ir::LoweredFunc> lowered_funcs;
+  ScheduleDesc trace;
+  void SetUp() override { Context::Global().ResetNameId(); }
+
+  void CheckTracingOutputs(const std::vector<Expr>& base, const ScheduleDesc& trace_desc) {
+    Context::Global().ResetNameId();
+    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
+    auto traced_outputs = ScheduleDesc::ReplayWithProto(trace_desc.ToProto(), &replay_sch);
+    ASSERT_EQ(base.size(), traced_outputs.size());
+    for (auto i = 0; i < base.size(); ++i) {
+      ASSERT_EQ(utils::GetStreamCnt(base.at(i)), utils::GetStreamCnt(traced_outputs.at(i)));
+    }
+  }
+
+  void CheckReplayResult(const ir::IRSchedule& ir_sch, const ScheduleDesc& trace_desc) {
+    Context::Global().ResetNameId();
+    ir::IRSchedule replay_sch = MakeIRSchedule(lowered_funcs);
+    trace_desc.Replay(&replay_sch);
+
+    // check the equality of module exprs between the original schedule
+    // and the schedule generated by replaying the traced ScheduleDesc
+    auto lhs_exprs = ir_sch.GetModule().GetExprs();
+    auto rhs_exprs = replay_sch.GetModule().GetExprs();
+    ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
+    for (auto i = 0; i < lhs_exprs.size(); ++i) {
+      ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)), utils::GetStreamCnt(rhs_exprs.at(i)));
+    }
+
+    // check the equality of the source code generated from them
+    ASSERT_EQ(utils::Trim(SourceCodeGen(ir_sch.GetModule(), lowered_funcs, target)),
+              utils::Trim(SourceCodeGen(replay_sch.GetModule(), lowered_funcs, target)));
+  }
+};
+
+TEST_F(TestScheduleDesc, Append_Replay) {
+  lowered_funcs = LowerCompute({32, 32}, target);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+
+  auto fused = ir_sch.Fuse("B", {0, 1});
+  trace.Append(ScheduleDesc::Step(
+      "FuseWithName", {}, {{"block_name", std::string("B")}, {"loops_index", std::vector({0, 1})}}, {fused}));
+  auto sample = ir_sch.SamplePerfectTile(fused, 2, 1, {4, -1});
+  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
+                                  {{"loop", std::vector({fused})}},
+                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{4, -1}}},
+                                  sample));
+  auto splited = ir_sch.Split(fused, sample);
+  trace.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({fused})}, {"factors", sample}}, {}, splited));
+
+  auto loops = ir_sch.GetLoops("B");
+  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+ 
fused = ir_sch.Fuse(loops); + trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused})); + sample = ir_sch.SamplePerfectTile(fused, 2, 1, {256, -1}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({fused})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{256, -1}}}, + sample)); + splited = ir_sch.Split(fused, sample); + trace.Append(ScheduleDesc::Step("Split", {{"loop", std::vector({fused})}, {"factors", sample}}, {}, splited)); + + // check the equality of results between the ir_sch and replaying of trace + CheckTracingOutputs(splited, trace); + CheckReplayResult(ir_sch, trace); + // check the equality of results between the ir_sch and replaying of its trace + CheckTracingOutputs(splited, ir_sch.GetTraceDesc()); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +// Test cases with `StepKind` prefix are to check the correctness of their StepKindInfo register +TEST_F(TestScheduleDesc, StepKind_GetAllBlocks) { + lowered_funcs = LowerCompute({32, 32}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto all_blocks = ir_sch.GetAllBlocks(); + trace.Append(ScheduleDesc::Step("GetAllBlocks", {}, {}, {all_blocks})); + CheckTracingOutputs(all_blocks, trace); + CheckTracingOutputs(all_blocks, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_GetChildBlocks) { + lowered_funcs = LowerCompute({32, 32, 64}, target, true); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto loops = ir_sch.GetLoops("C"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + ir_sch.ComputeAt(block_b, loops[1]); + trace.Append(ScheduleDesc::Step("ComputeAt", + {{"block", std::vector({block_b})}, {"loop", std::vector({loops[1]})}}, + {{"keep_unit_loops", false}}, + {})); + loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto root_block = ir_sch.GetRootBlock(loops[1]); + trace.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({loops[1]})}}, {}, {root_block})); + auto childblocks = ir_sch.GetChildBlocks(root_block); + trace.Append(ScheduleDesc::Step("GetChildBlocks", {{"expr", std::vector({root_block})}}, {}, childblocks)); + CheckTracingOutputs(childblocks, trace); + CheckTracingOutputs(childblocks, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_GetLoops) { + lowered_funcs = LowerCompute({32, 32}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto loops = ir_sch.GetLoops(block_b); + trace.Append(ScheduleDesc::Step("GetLoops", {{"block", std::vector({block_b})}}, {}, loops)); + CheckTracingOutputs(loops, trace); + CheckTracingOutputs(loops, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_GetLoopsWithName) { + lowered_funcs = LowerCompute({32, 32}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + CheckTracingOutputs(loops, trace); + CheckTracingOutputs(loops, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_GetBlock) { + lowered_funcs = LowerCompute({32, 
32, 32}, target);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+
+  auto block_b = ir_sch.GetBlock("B");
+  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  CheckTracingOutputs({block_b}, trace);
+  CheckTracingOutputs({block_b}, ir_sch.GetTraceDesc());
+}
+// TODO: Fix in the future; after the split variable renaming fix, this case has some problems.
+/*
+TEST_F(TestScheduleDesc, StepKind_Split) {
+  lowered_funcs = LowerCompute({32, 32, 32}, target);
+  ir::IRSchedule ir_sch_split_base = MakeIRSchedule(lowered_funcs);
+  ir::IRSchedule ir_sch_split = MakeIRSchedule(lowered_funcs);
+  ir::IRSchedule ir_sch_split_with_name = MakeIRSchedule(lowered_funcs);
+
+  // test split with inputs of Expr
+  auto loops = ir_sch_split_base.GetLoops("B");
+  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  auto sample = ir_sch_split_base.SamplePerfectTile(loops.front(), 2, 1, {4, -1});
+  trace.Append(ScheduleDesc::Step("SamplePerfectTile",
+                                  {{"loop", std::vector({loops.front()})}},
+                                  {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{4, -1}}},
+                                  sample));
+  auto splited = ir_sch_split_base.Split(loops.front(), sample);
+  trace.Append(
+      ScheduleDesc::Step("Split", {{"loop", std::vector({loops.front()})}, {"factors", sample}}, {}, splited));
+  CheckTracingOutputs(splited, trace);
+  CheckTracingOutputs(splited, ir_sch_split_base.GetTraceDesc());
+
+  // test split with inputs of int
+  loops = ir_sch_split.GetLoops("B");
+  splited = ir_sch_split.Split(loops.front(), {4, -1});
+  CheckTracingOutputs(splited, trace);
+  CheckTracingOutputs(splited, ir_sch_split.GetTraceDesc());
+
+  // test split with block name and inputs of int
+  splited = ir_sch_split_with_name.Split("B", 0, {4, -1});
+  CheckTracingOutputs(splited, trace);
+  CheckTracingOutputs(splited, ir_sch_split_with_name.GetTraceDesc());
+}
+*/
+TEST_F(TestScheduleDesc, StepKind_Fuse) {
+  lowered_funcs = LowerCompute({32, 32, 64}, target);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+
+  auto loops = ir_sch.GetLoops("B");
+  trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops));
+  auto fused = ir_sch.Fuse(loops);
+  trace.Append(ScheduleDesc::Step("Fuse", {{"loops", loops}}, {}, {fused}));
+  CheckTracingOutputs({fused}, trace);
+  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
+}
+
+TEST_F(TestScheduleDesc, StepKind_FuseWithName) {
+  lowered_funcs = LowerCompute({32, 32, 64}, target);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+
+  auto fused = ir_sch.Fuse("B", {0, 1, 2});
+  trace.Append(ScheduleDesc::Step(
+      "FuseWithName", {}, {{"block_name", std::string("B")}, {"loops_index", std::vector({0, 1, 2})}}, {fused}));
+  CheckTracingOutputs({fused}, trace);
+  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
+}
+
+TEST_F(TestScheduleDesc, StepKind_FuseWithBlock) {
+  lowered_funcs = LowerCompute({32, 32, 64}, target);
+  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
+
+  auto block_b = ir_sch.GetBlock("B");
+  trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
+  auto fused = ir_sch.Fuse(block_b, {0, 1, 2});
+  trace.Append(ScheduleDesc::Step("FuseWithBlock",
+                                  {{"block", std::vector({block_b})}},
+                                  {{"loops_index", std::vector({0, 1, 2})}},
+                                  {fused}));
+  CheckTracingOutputs({fused}, trace);
+  CheckTracingOutputs({fused}, ir_sch.GetTraceDesc());
+}
+
+TEST_F(TestScheduleDesc, StepKind_ComputeAt) {
+  lowered_funcs = LowerCompute({32, 32, 
64}, target, true); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto loops = ir_sch.GetLoops("C"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + ir_sch.ComputeAt(block_b, loops[1]); + trace.Append(ScheduleDesc::Step("ComputeAt", + {{"block", std::vector({block_b})}, {"loop", std::vector({loops[1]})}}, + {{"keep_unit_loops", false}}, + {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_SimpleComputeAt) { + lowered_funcs = LowerCompute({32, 32, 64}, target, true); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto loops = ir_sch.GetLoops("C"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + ir_sch.SimpleComputeAt(block_b, loops[2]); + trace.Append(ScheduleDesc::Step("SimpleComputeAt", + {{"block", std::vector({block_b})}, {"loop", std::vector({loops[2]})}}, + {{"keep_unit_loops", false}}, + {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_ReverseComputeAt) { + lowered_funcs = LowerCompute({32, 32, 64}, target, true); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_c = ir_sch.GetBlock("C"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + ir_sch.ReverseComputeAt(block_c, loops[1]); + trace.Append(ScheduleDesc::Step("ReverseComputeAt", + {{"block", std::vector({block_c})}, {"loop", std::vector({loops[1]})}}, + {{"keep_unit_loops", false}}, + {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_GetRootBlock) { + lowered_funcs = LowerCompute({32, 64}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto root_b = ir_sch.GetRootBlock(loops[1]); + trace.Append(ScheduleDesc::Step("GetRootBlock", {{"expr", std::vector({loops[1]})}}, {}, {root_b})); + CheckTracingOutputs({root_b}, trace); + CheckTracingOutputs({root_b}, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_CacheRead) { + lowered_funcs = LowerCompute({32, 64}, target, false, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto a_cache = ir_sch.CacheRead(block_b, 0, "local"); + trace.Append(ScheduleDesc::Step("CacheRead", + {{"block", std::vector({block_b})}}, + {{"read_buffer_index", 0}, {"memory_type", std::string("local")}}, + {a_cache})); + CheckTracingOutputs({a_cache}, trace); + CheckTracingOutputs({a_cache}, ir_sch.GetTraceDesc()); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_CacheWrite) { + lowered_funcs = LowerCompute({32, 64}, 
target, false, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); + trace.Append(ScheduleDesc::Step("CacheWrite", + {{"block", std::vector({block_b})}}, + {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, + {b_cache})); + CheckTracingOutputs({b_cache}, trace); + CheckTracingOutputs({b_cache}, ir_sch.GetTraceDesc()); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_SyncThreads) { + lowered_funcs = LowerCompute({64, 32}, target, true, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + auto b_cache = ir_sch.CacheWrite(block_b, 0, "local"); + trace.Append(ScheduleDesc::Step("CacheWrite", + {{"block", std::vector({block_b})}}, + {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, + {b_cache})); + auto block_c = ir_sch.GetBlock("C"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); + auto c_cache = ir_sch.CacheWrite(block_c, 0, "local"); + trace.Append(ScheduleDesc::Step("CacheWrite", + {{"block", std::vector({block_c})}}, + {{"write_buffer_index", 0}, {"memory_type", std::string("local")}}, + {c_cache})); + block_c = ir_sch.GetBlock("C"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); + ir_sch.SyncThreads(block_c, false); + trace.Append( + ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector({block_c})}}, {{"after_node", false}}, {})); + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.SyncThreads(block_b); + trace.Append( + ScheduleDesc::Step("SyncThreads", {{"ir_node", std::vector({block_b})}}, {{"after_node", true}}, {})); + + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_SetBuffer) { + lowered_funcs = LowerCompute({32, 64}, target, false, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.SetBuffer(block_b, "shared", true); + trace.Append(ScheduleDesc::Step("SetBuffer", + {{"block", std::vector({block_b})}}, + {{"memory_type", std::string("shared")}, {"fixed", true}}, + {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Reorder) { + lowered_funcs = LowerCompute({32, 64, 12}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[0]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 4}}}, + sample)); + auto splited = ir_sch.Split(loops[0], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[0]})}, 
{"factors", sample}}, {}, splited)); + + loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[2]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 2}}}, + sample)); + splited = ir_sch.Split(loops[2], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[2]})}, {"factors", sample}}, {}, splited)); + + loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + Expr ret = ir_sch.Reorder({loops[4], loops[0]}); + trace.Append(ScheduleDesc::Step("Reorder", {{"loops", std::vector({loops[4], loops[0]})}}, {}, {ret})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_ReorderWithBlock) { + lowered_funcs = LowerCompute({32, 32, 64}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[0]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 4}}}, + sample)); + auto splited = ir_sch.Split(loops[0], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[0]})}, {"factors", sample}}, {}, splited)); + + loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + sample = ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[2]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 2}}}, + sample)); + splited = ir_sch.Split(loops[2], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[2]})}, {"factors", sample}}, {}, splited)); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + Expr ret = ir_sch.Reorder("B", {2, 3, 1, 4, 0}); + trace.Append(ScheduleDesc::Step("ReorderWithBlock", + {{"block", std::vector({block_b})}}, + {{"loops_index", std::vector({2, 3, 1, 4, 0})}}, + {ret})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_ReorderWithName) { + lowered_funcs = LowerCompute({32, 32, 64}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto sample = ir_sch.SamplePerfectTile(loops[0], 2, 1, {-1, 4}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[0]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 4}}}, + sample)); + auto splited = ir_sch.Split(loops[0], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[0]})}, {"factors", sample}}, {}, splited)); + + loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + sample = 
ir_sch.SamplePerfectTile(loops[2], 2, 1, {-1, 2}); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[2]})}}, + {{"n", 2}, {"max_innermost_factor", 1}, {"decision", std::vector{-1, 2}}}, + sample)); + splited = ir_sch.Split(loops[2], sample); + trace.Append( + ScheduleDesc::Step("Split", {{"loop", std::vector({loops[2]})}, {"factors", sample}}, {}, splited)); + + Expr ret = ir_sch.Reorder("B", {4, 2, 3, 1, 0}); + trace.Append( + ScheduleDesc::Step("ReorderWithName", + {}, + {{"block_name", std::string("B")}, {"loops_index", std::vector({4, 2, 3, 1, 0})}}, + {ret})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Parallel) { + lowered_funcs = LowerCompute({32, 64}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + ir_sch.Parallel(loops[0]); + trace.Append(ScheduleDesc::Step("Parallel", {{"loop", std::vector({loops[0]})}}, {}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Vectorize) { + lowered_funcs = LowerCompute({32, 64}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + ir_sch.Vectorize(loops[1], 16); + trace.Append(ScheduleDesc::Step("Vectorize", {{"loop", std::vector({loops[1]})}}, {{"factor", 16}}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Unroll) { + lowered_funcs = LowerCompute({32, 2}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + ir_sch.Unroll(loops[1]); + trace.Append(ScheduleDesc::Step("Unroll", {{"loop", std::vector({loops[1]})}}, {}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_ComputeInline) { + lowered_funcs = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.ComputeInline(block_b); + trace.Append(ScheduleDesc::Step("ComputeInline", {{"schedule_block", std::vector({block_b})}}, {}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) { + lowered_funcs = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const"); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + auto block_c = ir_sch.GetBlock("C"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c})); + ir_sch.ReverseComputeInline(block_c); + trace.Append(ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector({block_c})}}, {}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Bind) { + lowered_funcs = LowerCompute({32, 128}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + 
+ auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + ir_sch.Bind(loops[0], "blockIdx.x"); + trace.Append(ScheduleDesc::Step( + "Bind", {{"loop", std::vector({loops[0]})}}, {{"thread_axis", std::string("blockIdx.x")}}, {})); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Rfactor) { + Expr M(32); + Expr N(2); + Expr K(16); + + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + Var k(16, "k0"); + auto C = Compute( + {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + lowered_funcs = + cinn::lang::LowerVec("test_rfactor", CreateStages({A, B, C}), {A, B, C}, {}, {}, nullptr, target, true); + + cinn::common::Context::Global().ResetNameId(); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + cinn::common::Context::Global().ResetNameId(); + + auto loops = ir_sch.GetLoops("C"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("C")}}, loops)); + auto new_rf_tensor = ir_sch.Rfactor(loops[2], 0); + trace.Append( + ScheduleDesc::Step("Rfactor", {{"rf_loop", std::vector({loops[2]})}}, {{"rf_axis", 0}}, {new_rf_tensor})); + CheckTracingOutputs({new_rf_tensor}, trace); + CheckTracingOutputs({new_rf_tensor}, ir_sch.GetTraceDesc()); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_MergeExprs) { + auto funcs_0 = LowerCompute({32, 128}, target); + auto funcs_1 = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const"); + + ir::IRSchedule ir_sch = + ir::IRSchedule(ir::ModuleExpr({optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)})); + ir_sch.MergeExprs(); + trace.Append(ScheduleDesc::Step("MergeExprs", {}, {}, {})); + ir::IRSchedule replay_sch = + ir::IRSchedule(ir::ModuleExpr({optim::IRCopy(funcs_0[0]->body), optim::IRCopy(funcs_0[0]->body)})); + trace.Replay(&replay_sch); + + auto lhs_exprs = ir_sch.GetModule().GetExprs(); + auto rhs_exprs = replay_sch.GetModule().GetExprs(); + ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size()); + for (auto i = 0; i < lhs_exprs.size(); ++i) { + ASSERT_EQ(utils::GetStreamCnt(lhs_exprs.at(i)), utils::GetStreamCnt(rhs_exprs.at(i))); + } +} + +TEST_F(TestScheduleDesc, StepKind_Annotate) { + lowered_funcs = LowerCompute({32, 128}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k1", int(64)); + trace.Append(ScheduleDesc::Step("AnnotateIntAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k1")}, {"value", int(64)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k2", bool(true)); + trace.Append(ScheduleDesc::Step("AnnotateBoolAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k2")}, {"value", bool(true)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k3", float(2.0)); + trace.Append(ScheduleDesc::Step("AnnotateFloatAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k3")}, {"value", float(2.0)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + 
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k4", std::string("v4")); + trace.Append(ScheduleDesc::Step("AnnotateStringAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k4")}, {"value", std::string("v4")}}, + {})); + + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_Unannotate) { + lowered_funcs = LowerCompute({32, 128}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k1", int(64)); + trace.Append(ScheduleDesc::Step("AnnotateIntAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k1")}, {"value", int(64)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k2", bool(true)); + trace.Append(ScheduleDesc::Step("AnnotateBoolAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k2")}, {"value", bool(true)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Unannotate(block_b, "k1"); + trace.Append( + ScheduleDesc::Step("Unannotate", {{"block", std::vector({block_b})}}, {{"key", std::string("k1")}}, {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Unannotate(block_b, "k2"); + trace.Append( + ScheduleDesc::Step("Unannotate", {{"block", std::vector({block_b})}}, {{"key", std::string("k2")}}, {})); + + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_SamplePerfectTile) { + Expr M(1024); + Var n(1, "n"); + + Placeholder A("A", {M}); + auto B = Compute( + {M}, [&](Expr i) { return A(i) + n; }, "B"); + lowered_funcs = + cinn::lang::LowerVec("test_sample_perfect_tile", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true); + + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + auto loops = ir_sch.GetLoops("B"); + trace.Append(ScheduleDesc::Step("GetLoopsWithName", {}, {{"block_name", std::string("B")}}, loops)); + auto result = ir_sch.SamplePerfectTile(loops[0], 2, 64); + std::vector decision; + std::transform(result.begin(), result.end(), std::back_inserter(decision), [](Expr x) { return x.as_int32(); }); + trace.Append(ScheduleDesc::Step("SamplePerfectTile", + {{"loop", std::vector({loops[0]})}}, + {{"n", 2}, {"max_innermost_factor", 64}, {"decision", decision}}, + result)); + CheckTracingOutputs(result, trace); + CheckTracingOutputs(result, ir_sch.GetTraceDesc()); + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +TEST_F(TestScheduleDesc, StepKind_SampleCategorical) { + lowered_funcs = LowerCompute({32, 32, 64}, target, true); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + Expr ret = ir_sch.SampleCategorical({1, 2, 3}, {1.0, 2.0, 3.0}); + std::vector decision = {ret.as_int32()}; + trace.Append(ScheduleDesc::Step("SampleCategorical", + {}, + {{"candidates", std::vector({1, 2, 3})}, + {"probs", std::vector({1.0, 2.0, 3.0})}, + {"decision", decision}}, + {ret})); + CheckTracingOutputs({ret}, trace); + CheckTracingOutputs({ret}, ir_sch.GetTraceDesc()); 
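+  // The sampled result is recorded in the step's "decision" attribute, so
+  // replaying the trace is deterministic even though SampleCategorical
+  // itself draws randomly.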
+ CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc new file mode 100755 index 0000000000000..f0e53231fd33e --- /dev/null +++ b/paddle/cinn/ir/tensor.cc @@ -0,0 +1,590 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/ir/tensor.h" + +#include + +#include "cinn/cinn.h" +#include "cinn/common/arithmatic.h" +#include "cinn/common/axis.h" +#include "cinn/common/cas.h" +#include "cinn/common/common.h" +#include "cinn/common/ir_util.h" +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_visitor.h" +#include "cinn/ir/operation.h" +#include "cinn/lang/compute.h" +#include "cinn/poly/isl_utils.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace ir { + +Tensor _Tensor_::Make(const std::string &name, + Type dtype, + const std::vector &shape, + const std::vector &domain, + FunctionRef fn, + const std::vector &reduce_axis) { + CHECK(!name.empty()) << "Tensor name is set empty"; + auto n = make_shared<_Tensor_>(); + n->name = name; + n->shape = shape; + n->domain = domain; + n->reduce_axis = reduce_axis; + n->set_type(dtype); + n->operation = fn; + n->InitAxis(); + + return Tensor(n); +} + +size_t Tensor::ndims() const { return operator->()->shape.size(); } + +std::set _Tensor_::GetDependTensorNames() const { + std::set names; + + auto add_depend_tensors_from_expr = [&](Expr expr) { + auto tensors = + CollectIRNodes(expr, [&](const Expr *x) { return x->as_tensor() && x->as_tensor()->name != this->name; }); + for (auto &e : tensors) { + names.insert(e.as_tensor()->name); + } + }; + + if (is_compute_node()) { + add_depend_tensors_from_expr(body()); + } else if (is_call_node()) { + add_depend_tensors_from_expr(body()); + } else if (is_extern_call_node()) { + add_depend_tensors_from_expr(body()); + } else if (is_placeholder_node()) { + return names; + } else { + CINN_NOT_IMPLEMENTED + } + + return names; +} + +Expr Tensor::operator()(const std::vector &indices) const { + CHECK(!self()->is_tuple()) << "should extract a specific value from the tuple and operate on that instead"; + auto *node = operator->(); + + CHECK_EQ(indices.size(), ndims()) << "number of indices not match the dimension"; + + return Load::Make(*this, indices); +} + +Expr _Tensor_::inline_expanded(const std::vector &indices) { + CHECK(is_compute_node()); + return get_compute_op()->producer_fn(indices); +} + +const char *_Tensor_::operation_type() const { + if (!operation.defined()) return ""; + return operation->as()->func_type(); +} + +bool _Tensor_::is_compute_node() const { return std::strcmp(operation_type(), ir::ComputeOp::__func_type__) == 0; } +bool _Tensor_::is_placeholder_node() const { + return std::strcmp(operation_type(), ir::PlaceholderOp::__func_type__) == 0; +} +bool _Tensor_::is_call_node() const { return 
std::strcmp(operation_type(), ir::CallOp::__func_type__) == 0; } +bool _Tensor_::is_extern_call_node() const { + if (std::strcmp(operation_type(), ir::CallOp::__func_type__) == 0) { + auto *op = operation->as(); + auto *call = op->call_expr.As(); + if (call) { + return call->is_extern_call(); + } + } + return false; +} +bool _Tensor_::is_buffer_shared_node() const { + return std::strcmp(operation_type(), ir::BufferShareOp::__func_type__) == 0; +} + +bool _Tensor_::is_preceding_view_node() const { + return std::strcmp(operation_type(), ir::PrecedingViewOp::__func_type__) == 0; +} + +ComputeOp *_Tensor_::get_compute_op() const { + if (!is_compute_node()) return nullptr; + return operation->as(); +} + +PlaceholderOp *_Tensor_::get_placeholder_op() const { + if (!is_placeholder_node()) return nullptr; + return operation->as(); +} + +void _Tensor_::InitAxis() const { + // CHECK(!domain_without_reduce_axis().empty()); + axis_ = common::GenDefaultAxis(domain_without_reduce_axis().size()); +} + +bool _Tensor_::has_expression() const { + return (!is_placeholder_node()) && (!is_tuple_get()) && (!is_buffer_shared_node()); +} + +isl::set _Tensor_::GenerateIslDomain() const { + // include the reduce axis. + std::vector dims; + + if (has_expression()) { + if (axis_.empty()) InitAxis(); + auto domain = domain_with_reduce_axis(); + CHECK_EQ(axis_with_reduce().size(), domain.size()); + auto _axis_with_reduce = axis_with_reduce(); + for (int i = 0; i < domain.size(); i++) { + auto dim = domain[i]; + if (dim.is_constant()) { + dims.emplace_back(_axis_with_reduce[i]->name, 0, dim.as_int32() - 1); + } else { + dims.emplace_back(_axis_with_reduce[i]->name, Expr(0), Sub::Make(dim, common::make_const(1))); + } + } + } + + poly::Domain isl_domain(Context::isl_ctx(), name, dims); + VLOG(1) << "name:" << this->name << ", domain: " << isl_domain.__str__(); + return isl_domain.to_isl(); +} + +std::vector _Tensor_::expr_fields() { + std::vector res; + const char *func_type = operation->as()->func_type(); + if (operation.defined()) { + if (is_compute_node()) { + auto *op = operation->as(); + for (auto &expr : op->body) res.push_back(&expr); + } else if (is_placeholder_node()) { + auto *op = operation->as(); + } else if (is_call_node()) { + auto *op = operation->as(); + for (auto &expr : op->read_args()) res.push_back(&expr); + } else if (is_buffer_shared_node()) { + } else { + CINN_NOT_IMPLEMENTED + } + } + + for (auto &e : shape) { + res.push_back(&e); + } + for (auto &e : domain) { + res.push_back(&e); + } + return res; +} + +std::vector _Tensor_::expr_fields() const { + std::vector res; + const char *func_type = operation->as()->func_type(); + if (operation.defined()) { + if (is_compute_node()) { + auto *op = operation->as(); + for (auto &expr : op->body) res.push_back(&expr); + } else if (is_placeholder_node()) { + auto *op = operation->as(); + } else if (is_call_node()) { + auto *op = operation->as(); + for (auto &expr : op->read_args()) res.push_back(&expr); + } else if (is_buffer_shared_node()) { + } else { + LOG(ERROR) << "func_type: " << func_type; + CINN_NOT_IMPLEMENTED + } + } + + for (auto &e : shape) { + res.push_back(&e); + } + for (auto &e : domain) { + res.push_back(&e); + } + + return res; +} + +_Tensor_::~_Tensor_() {} + +Expr _Tensor_::body() const { + if (is_placeholder_node()) return Expr(); + if (is_buffer_shared_node()) return Expr(); + if (is_compute_node()) return operation->as()->body.front(); + if (is_call_node()) return operation->as()->call_expr; + CINN_NOT_IMPLEMENTED; +} + +Expr 
*_Tensor_::mutable_body() {
+  if (is_placeholder_node()) return nullptr;
+  if (is_buffer_shared_node()) return nullptr;
+  if (is_compute_node()) return &operation->as<ComputeOp>()->body.front();
+  if (is_call_node()) return &operation->as<CallOp>()->call_expr;
+  CINN_NOT_IMPLEMENTED
+}
+
+ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target) const {
+  CHECK(contains_reduce_axis()) << "InitReduction only works on a reduce tensor";
+  // Return early if the init tensor already exists.
+  std::string init_reduce_tensor_name = GenReduceInitTensorNameOf(name);
+  if (stages->Lookup(init_reduce_tensor_name)) return stages[this]->LookupCtrlDepend(init_reduce_tensor_name);
+
+  // Create a new init tensor.
+  auto init_tensor = lang::Compute(
+      domain, [=](const std::vector<Expr> &axis) { return GetReduceInitVal(); }, init_reduce_tensor_name);
+  stages->InsertLazily(init_tensor);
+  std::string this_transform = isl_map_to_str(stages[this]->transform().get());
+  isl::ctx this_ctx = stages[this]->transform().ctx();
+  isl::map temp_transform(this_ctx, this_transform);
+  int reduce_axis_num = this->reduce_axis.size();
+  auto dim_out_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_out);
+  auto dim_in_size = isl_map_dim(stages[this]->transform().get(), isl_dim_in);
+  auto dim_in_names = poly::isl_get_dim_names(stages[this]->transform(), isl_dim_in);
+  std::vector<std::string> reduce_axis_input = stages[this]->origin_reduce_axis_names();
+  auto origin_domain = stages[this]->domain();
+  auto reduce_axis_output = poly::GetRelatedOutputAxies(temp_transform, origin_domain, reduce_axis_input);
+  std::set<std::string> reduce_axis_output_set;
+  for (auto &i : reduce_axis_output) {
+    reduce_axis_output_set.insert(i);
+  }
+  int compute_at_axis = -1;
+  for (auto &i : dim_out_names) {
+    if (reduce_axis_output_set.count(i) == 0) {
+      compute_at_axis++;
+    } else {
+      break;
+    }
+  }
+
+  temp_transform = poly::RemoveAxiesByOutputNames(temp_transform, origin_domain, reduce_axis_output);
+
+  //! When the first axis is not a reduce axis, do ComputeAt.
+  if (compute_at_axis >= 0) {
+    stages[init_tensor]->ComputeAt2(stages[this], compute_at_axis);
+    init_tensor->new_indices = this->new_indices;
+    stages[this]->CtrlDepend(init_tensor);
+    stages[init_tensor]->ShareBufferWith(stages[this]);
+    init_tensor->shape = shape;
+    return init_tensor;
+  }
+  //! When the reduce axes are reordered to the front, ComputeAt is illegal,
+  //! so we just copy the transform and the forloop info.
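+  //! Sketch of that copy: retarget the transform's input/output tuple names
+  //! to the init tensor, then migrate each forloop annotation whose output
+  //! axis name matches between the two transforms.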
+ isl_map_set_tuple_name(temp_transform.get(), isl_dim_in, init_reduce_tensor_name.c_str()); + isl_map_set_tuple_name(temp_transform.get(), isl_dim_out, init_reduce_tensor_name.c_str()); + stages[init_tensor]->SetTransform(temp_transform); + auto init_dim_out_names = poly::isl_get_dim_names(temp_transform, isl_dim_out); + std::map temp_forloop_info = stages[this]->forloop_infos(); + std::map init_forloop_info; + for (auto &i : temp_forloop_info) { + for (int j = 0; j < init_dim_out_names.size(); j++) { + if (i.first < 0) continue; + int new_i = poly::isl_get_original_axes_from_optimized_level(stages[this]->transformed_domain().get(), i.first); + if (dim_out_names[new_i] == init_dim_out_names[j]) { + stages[init_tensor]->AddForloopInfo(j, i.second); + } + } + } + init_tensor->new_indices = this->new_indices; + stages[this]->CtrlDepend(init_tensor); + stages[init_tensor]->ShareBufferWith(stages[this]); + init_tensor->shape = shape; + return init_tensor; +} + +ir::Tensor _Tensor_::GetInitTensor(poly::StageMap stages, const Target &target) const { + return InitReduction(stages, target); +} + +Expr _Tensor_::tensor_store_expanded_body() { + CHECK(!is_placeholder_node()) << "placeholder should not expand store"; + + Expr final_body = body(); + if (shape.empty()) return final_body; + + std::vector g_axis = common::GenDefaultAxisAsExpr(shape.size()); + if (!new_indices.empty()) { + g_axis = new_indices; + } + + auto *reduce_node = body().As(); + if (reduce_node) { + final_body = reduce_node->body; + switch (reduce_node->reduce_type) { + case ir::Reduce::kSum: + final_body = Tensor(this)(g_axis) + final_body; + break; + case ir::Reduce::kMul: + final_body = Tensor(this)(g_axis) * final_body; + break; + case ir::Reduce::kMax: + final_body = Max::Make(Tensor(this)(g_axis), final_body); + break; + case ir::Reduce::kMin: + final_body = Min::Make(Tensor(this)(g_axis), final_body); + break; + case ir::Reduce::kAll: + final_body = Tensor(this)(g_axis) && final_body; + break; + case ir::Reduce::kAny: + final_body = Tensor(this)(g_axis) || final_body; + break; + default: + CINN_NOT_IMPLEMENTED + } + } + + if (is_tuple()) return final_body; + + return ir::Store::Make(Expr(Buffer(this)), final_body, g_axis); +} + +void _Tensor_::Bind(lang::Buffer &buffer) { + // CHECK(!inlined()) << "Inlined tensor should bing buffer"; + CHECK(!buffer->type().is_void()); + if (this->buffer.defined()) { + // remove the old buffer + if (this->buffer == buffer.buffer()) return; + this->buffer->Unbind(this); + } + // Extract the tensors thouse has binded to this buffer. + buffer_depended_tensor_names_ = buffer.buffer()->binded_tensor_names(); + + buffer.buffer()->BindTo(this); + CHECK(!buffer->binded_tensor_names().empty()); + this->buffer = buffer.buffer(); + CHECK(this->buffer.defined()); +} + +void _Tensor_::Bind(const Buffer &buffer) { + lang::Buffer buf(buffer); + Bind(buf); +} + +void _Tensor_::WithBuffer(const Type &type) { + Type buf_type = type.is_void() ? type_ : type; + lang::Buffer buf(buf_type); + buf->target = common::DefaultHostTarget(); + Bind(buf); +} + +void _Tensor_::WithBuffer(const std::string &memory_type, const std::string &buffer_name, const Type &type) { + Type buf_type = type.is_void() ? 
type_ : type; + if (this->buffer.defined()) { + this->buffer->dtype = buf_type; + this->buffer->name = buffer_name; + if (memory_type == "shared") { + this->buffer->memory_type = MemoryType::GPUShared; + } else if (memory_type == "local") { + this->buffer->memory_type = MemoryType::GPULocal; + } else if (memory_type == "global") { + this->buffer->memory_type = MemoryType::Heap; + } else { + LOG(FATAL) << "Not supported memory type " << memory_type; + } + } else { + lang::Buffer buf(buf_type, buffer_name); + buf->target = common::DefaultHostTarget(); + Bind(buf); + + if (memory_type == "shared") { + buf->memory_type = MemoryType::GPUShared; + } else if (memory_type == "local") { + buf->memory_type = MemoryType::GPULocal; + } else if (memory_type == "global") { + buf->memory_type = MemoryType::Heap; + } else { + LOG(FATAL) << "Not supported memory type " << memory_type; + } + } +} + +bool _Tensor_::HasSameShapeWith(const Tensor &other) const { + if (shape.size() != other->shape.size()) return false; + + for (int i = 0; i < shape.size(); i++) { + Expr dim0 = common::AutoSimplify(shape[i]); + Expr dim1 = common::AutoSimplify(other->shape[i]); + + if (dim0 != dim1) return false; + } + return true; +} + +Tensor _Tensor_::TupleGet(int offset) const { + CHECK(is_tuple()); + auto *call = body().As(); + CHECK_LT(offset, call->write_args.size()); + auto tensor = call->write_args[offset].as_tensor_ref(); + tensor->WithBuffer(); + return tensor; +} + +bool _Tensor_::is_tuple() const { + if (!has_expression()) return false; + auto *call = body().As(); + if (call && call->is_extern_call() && !call->write_args.empty()) return true; + return false; +} + +std::vector _Tensor_::domain_with_reduce_axis() const { + if (reduce_axis.empty()) return domain; + auto res = domain; + for (const Var &axis : reduce_axis) { + CHECK(axis->upper_bound.type().is_int(32)) << axis->upper_bound; + res.push_back(axis->upper_bound); + } + return res; +} + +bool operator<(const Tensor &a, const Tensor &b) { return a->name < b->name; } + +Tensor::Tensor(const std::string &name, + Type dtype, + const std::vector &shape, + const std::vector &domain, + FunctionRef fn, + const std::vector &reduce_axis) + : IrNodeRef(_Tensor_::Make(name, dtype, shape, domain, fn, reduce_axis).self()) {} + +bool _Tensor_::is_tuple_get() const { + return is_call_node() && operation.defined() && + operation->as()->func_type() == ir::CallOp::__func_type__ && + operation->as()->is_tuple_get; +} + +bool _Tensor_::IsDependOnStatement(absl::string_view statement) { + if (!is_compute_node()) { + return false; + } + + auto depend_tensors = DependingTensorNames(); + for (const auto &x : depend_tensors) { + if (x == statement) return true; + } + return false; +} + +std::set _Tensor_::DependingTensorNames() { + std::set res; + if (body().defined()) { + auto depend_tensors = ir::CollectIRNodes(body(), [](const Expr *x) -> bool { return x->as_tensor(); }); + for (const auto &x : depend_tensors) { + if (x.get() != this) { + res.insert(x.as_tensor()->name); + } + } + } + return res; +} + +const std::vector &_Tensor_::axis() const { + CHECK_EQ(axis_.size(), domain_without_reduce_axis().size()); + return axis_; +} + +std::vector _Tensor_::axis_with_reduce() const { + auto axis = axis_; + axis.insert(axis.end(), reduce_axis.begin(), reduce_axis.end()); + return axis; +} + +bool _Tensor_::Uses(const Tensor &other) const { + auto loads = ir::CollectIRNodes(body(), [&](const Expr *x) { + auto *loadn = x->As(); + if (!loadn) return false; + return 
loadn->tensor.as_tensor()->name == other->name; + }); + return !loads.empty(); +} + +ir::Tensor _Tensor_::Reshape(const std::vector &shape, poly::StageMap stages) const { + CHECK(!stages[this]->inlined()); + auto op = BufferShareOp::Make(); + auto n = make_shared<_Tensor_>(); + auto selft = Tensor(const_cast(this)); + + { + Expr this_num_elements = Expr(1); + for (auto &e : this->shape) this_num_elements = this_num_elements * e; + + Expr num_elements = Expr(1); + for (auto &e : shape) num_elements = num_elements * e; + + CHECK(MathIsZero(this_num_elements - num_elements)) << "number of elements mismatch"; + } + + n->name = Context::Global().NewName(name + "_reshape"); + n->shape = shape; + n->domain = shape; + n->set_type(type()); + n->operation = op; + n->InitAxis(); + + auto t = Tensor(n); + stages->InsertLazily(t); + + stages[n]->ShareBufferWith(stages[this]); + stages[n]->CtrlDepend(selft); + return t; +} + +ir::Tensor _Tensor_::ReshapeCopied(const std::vector &shape, poly::StageMap stages) const { + auto t = ir::Tensor(const_cast(this)); + auto copied = Compute( + domain, + [=](const std::vector &axis) { return t(axis); }, + Context::Global().NewName(this->name + "_copied")); + stages->InsertLazily(copied); + auto res = copied->Reshape(shape, stages); + stages->InsertLazily(res); + return res; +} + +Shared CreateStage(Tensor tensor) { + auto isl_domain = tensor->GenerateIslDomain(); + return poly::Stage::New(isl_domain, tensor->body(), tensor.self()); +} + +std::string GenReduceInitTensorNameOf(const std::string &tensor_name) { return tensor_name + "__reduce_init"; } + +bool _Tensor_::is_reduce_sum() const { + if (!contains_reduce_axis()) return false; + return body().As() && body().As()->reduce_type == ir::Reduce::ReduceType::kSum; +} +bool _Tensor_::is_reduce_mul() const { + if (!contains_reduce_axis()) return false; + return body().As() && body().As()->reduce_type == ir::Reduce::ReduceType::kMul; +} + +Expr _Tensor_::GetReduceInitVal() const { + CHECK(is_reduce_tensor()); + return body().As()->init; +} + +bool _Tensor_::IsReduceInited(poly::StageMap stages) const { return stages->Lookup(GenReduceInitTensorNameOf(name)); } + +void _Tensor_::Verify() const { + CHECK(!shape.empty()); + CHECK(!domain.empty()); + CHECK(!name.empty()) << "Name of tensor should be set"; +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/tensor.h b/paddle/cinn/ir/tensor.h new file mode 100644 index 0000000000000..437fe62e6d31c --- /dev/null +++ b/paddle/cinn/ir/tensor.h @@ -0,0 +1,342 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cinn/common/graph_utils.h" +#include "cinn/ir/buffer.h" +#include "cinn/ir/function_base.h" +#include "cinn/lang/buffer.h" +#include "cinn/poly/stage.h" + +namespace cinn { + +namespace ir { +class Tensor; +} // namespace ir + +namespace lang { +template +struct Placeholder; + +void InitReduceTensor(poly::StageMap stages, + const ir::Tensor& tensor, + const Target& target = common::DefaultHostTarget()); +} // namespace lang + +namespace ir { +namespace detail { +constexpr bool LE(int a, int b) { return a <= b; } +constexpr bool GE(int a, int b) { return a >= b; } + +} // namespace detail + +class _Tensor_; +class Tensor; + +class Tensor : public ir::IrNodeRef { + public: + Tensor() = default; + explicit Tensor(ir::IrNode* n) : IrNodeRef(n) {} + Tensor(const std::string& name, + Type dtype, + const std::vector& shape, + const std::vector& domain, + FunctionRef fn, + const std::vector& reduce_axis = {}); + + //! Get number of dimensions. + size_t ndims() const; + + /** + * Take elements from the tensor. + * This take one or multiple expressions as indices. + * + * usage: + * + * Tensor A; + * A(i,j) get the [i][j] element. + */ + // @{ + Expr operator()(const Expr& a) const { return operator()(std::vector({a})); } + template + inline typename std::enable_if::type operator()(Args&&... args) const { + return operator()({std::forward(args)...}); + } + // @} + + /** + * Take elements from the tensor. + * @param indices The indices. + * @return The result expression representing a tensor read. + */ + Expr operator()(const std::vector& indices) const; + + friend bool operator<(const Tensor& a, const Tensor& b); + + _Tensor_* self() { return operator->(); } + const _Tensor_* self() const { return operator->(); } + + inline const _Tensor_* operator->() const { return As<_Tensor_>(); } + inline _Tensor_* operator->() { return As<_Tensor_>(); } + + //! Cast to an Expr. + inline operator Expr() const { return Expr(get()); } +}; + +/** + * \brief Generate the name of the reduce init tensor of \p tensor. + * This is used for retrieving the corresponding reduction-init tensor from a stage map by name. + */ +std::string GenReduceInitTensorNameOf(const std::string& tensor_name); + +class ComputeOp; +class PlaceholderOp; +struct ReadCacheRelation; +struct WriteCacheRelation; + +/** + * _Tensor_ holds the content of a Tensor. + * + * NOTE(All) Some rules: + * + * 1. a _Tensor_ is a node in SSA, so every tensor's name should be unique, + * 2. never try to change a tensor's name, that will cause chaos. + */ +class _Tensor_ : public ExprNode<_Tensor_> { + public: + //! Shape of this tensor(buffer). + std::vector shape; + //! The domain of each axis(without reduce_axis) + // TODO(Superjomn) support ISL domain. + std::vector domain; + + std::vector reduce_axis; + //! The operation that generates Tensor. + FunctionRef operation; + //! Name of this tensor. + std::string name; + //! The bound buffer, for each tensor if it is not inline. + Buffer buffer; + //! Normal axis. + mutable std::vector axis_; + + std::vector new_indices{}; + std::vector domain_with_reduce_axis() const; + const std::vector& domain_without_reduce_axis() const { return domain; } + + //! Generate a tensor from a function. 
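+  //! A direct-construction sketch (tensors are normally created through
+  //! lang::Compute; the `compute_op` below is illustrative):
+  //!   ir::Tensor t = _Tensor_::Make("C", Float(32), /*shape=*/{M, N},
+  //!                                 /*domain=*/{M, N}, compute_op);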
+ static Tensor Make(const std::string& name, + Type dtype, + const std::vector& shape, + const std::vector& domain, + FunctionRef fn, + const std::vector& reduce_axis = {}); + + void Verify() const override; + + bool IsReduceInited(poly::StageMap stages) const; + + //! Tell whether this tensor represents a tuple (consists of one or multiple tensors as output of a extern Call). + bool is_tuple() const; + bool is_tuple_get() const; + + Tensor TupleGet(int offset) const; + + /** + * Get the names of the dependency(read or write) tensors. + * e.g. A[i] = C[i]*2 + D[i], A's dependency tensors are {C,D} + */ + std::set GetDependTensorNames() const; + + /** + * \brief Tell whether this tensor's computation relays on a specific statement. + * @param statement The name of a statement(equivalent to the id of tensor). + * @return A boolean. + */ + bool IsDependOnStatement(absl::string_view statement); + + /** + * Get the names of the tensors thouse this tensor depends on. + */ + std::set DependingTensorNames(); + + /** + * Get a new tensor with the \p shape, but the underlying buffer shared. + * NOTE the tensor to Reshape should not be an inlined computation. + */ + ir::Tensor Reshape(const std::vector& shape, poly::StageMap stages) const; + + /** + * Get a new tensor with the \p shape with a newly allocated buffer. + * NOTE the tensor to Reshape should not be an inlined computation. + */ + ir::Tensor ReshapeCopied(const std::vector& shape, poly::StageMap stages) const; + + /** + * Tell whether this tensor has same shape with \p other. + */ + bool HasSameShapeWith(const Tensor& other) const; + + //! Operation related. + // @{ + bool is_compute_node() const; + bool is_placeholder_node() const; + bool is_call_node() const; + bool is_extern_call_node() const; + bool is_preceding_view_node() const; + bool is_buffer_shared_node() const; + const char* operation_type() const; + ComputeOp* get_compute_op() const; + PlaceholderOp* get_placeholder_op() const; + // @} + + //! The expression generate this tensor, will be empty if it is a PlaceHolder. + Expr body() const; + Expr* mutable_body(); + //! Get the expression with `store(tensor)` inserted into the body. + Expr tensor_store_expanded_body(); + + Expr inline_expanded(const std::vector& indices); + + //! Tell whether contain a reduce axis. + bool contains_reduce_axis() const { return !reduce_axis.empty(); } + bool is_reduce_tensor() const { return contains_reduce_axis(); } + bool is_reduce_sum() const; + bool is_reduce_mul() const; + //! Get the initial value of a reduce tensor. + Expr GetReduceInitVal() const; + + std::vector expr_fields() override; + std::vector expr_fields() const override; + + /** + * The normal axis without reducing ones. + */ + const std::vector& axis() const; + + /** + * The axis with the reduce ones. + */ + std::vector axis_with_reduce() const; + + /** + * Get the tensors thouse depend on the same buffer belong to this tensor. + */ + const std::set& buffer_depended_tensor_names() const { return buffer_depended_tensor_names_; } + + static const IrNodeTy _node_type_ = IrNodeTy::_Tensor_; + + _Tensor_() : ExprNode<_Tensor_>(Float(32)) {} + + bool has_expression() const; + + ~_Tensor_(); + + /** + * Tell if this tensor uses other tensors in the body. + */ + bool Uses(const ir::Tensor& other) const; + + //! Bind to a buffer, will persist data to the buffer in runtime. + void Bind(lang::Buffer& buffer); // NOLINT + void Bind(const Buffer& buffer); + void UnBind(lang::Buffer& buffer); // NOLINT + + //! 
Create a buffer belong to this tensor. + void WithBuffer(const Type& type = Void()); + void WithBuffer(const std::string& memory_type, const std::string& buffer_name = "", const Type& type = Void()); + Tensor GetInitTensor(poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const; + + private: + //! Initialize the axis field after the shape field is assigned. + void InitAxis() const; + + isl::set GenerateIslDomain() const; + + /** + * Create the initialization tensor. + * @param stages The stages. + * @param init_val The initial value. + * @return The initializing tensor. + */ + ir::Tensor InitReduction(poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const; + + //! The names of the tensors depend the same buffer and should schedule before this. + std::set buffer_depended_tensor_names_; + + friend Shared CreateStage(Tensor tensor); + + friend void lang::InitReduceTensor(poly::StageMap stages, const ir::Tensor& tensor, const Target& target); +}; + +Shared CreateStage(Tensor tensor); + +class _Operation_; +class Operation : public FunctionRef { + public: + Operation() = default; + explicit Operation(IrNode* n) : FunctionRef(n) {} + + inline const _Operation_* operator->() const { return reinterpret_cast<_Operation_*>(get()); } + inline _Operation_* operator->() { return reinterpret_cast<_Operation_*>(get()); } + + //! Get the i-th output of the operation. + // Tensor output(size_t i) const; + + std::string name; +}; + +class _Operation_ : public ir::FunctionBase { + public: + //! Optional name of the operation. + std::string name; + //! Optional tag of the operation. + std::string tag; + //! Additional attributes of the operation. + std::map attrs; + + const std::string& func_name() const final { return name; } + + void Verify() const override {} + + //! The function type. + virtual const char* func_type() const = 0; +}; + +} // namespace ir +} // namespace cinn + +namespace std { + +template <> +struct hash { + inline size_t operator()(const cinn::ir::Tensor& x) { + // We treat the tensor's name as the unique identifier. + return std::hash()(x->name); + } +}; + +} // namespace std diff --git a/paddle/cinn/ir/tensor_test.cc b/paddle/cinn/ir/tensor_test.cc new file mode 100755 index 0000000000000..54c46bfa7028b --- /dev/null +++ b/paddle/cinn/ir/tensor_test.cc @@ -0,0 +1,211 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/ir/tensor.h" + +#include + +#include "cinn/backends/codegen_c.h" +#include "cinn/backends/llvm/execution_engine.h" +#include "cinn/cinn.h" +#include "cinn/common/test_helper.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/builtin.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/lower.h" +#include "cinn/lang/packed_func.h" +#include "cinn/lang/placeholder.h" + +namespace cinn { +namespace ir { +using utils::GetStreamCnt; +using utils::Trim; + +TEST(Tensor, inlined) { + Expr M(100), N(20); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + // C is inlined + Tensor C = lang::Compute( + {M, N}, [=](Var i, Var j) { return A(i, j) + B(i, j); }, "C"); + + Tensor D = lang::Compute( + {M, N}, [=](Var i, Var j) -> Expr { return C(i, j) * 2.f + 1.f; }, "D"); + + auto stages = CreateStages({D}); + stages[C]->ComputeInline(); + + auto func = lang::Lower("func_C", stages, {A, B, D}); + std::cout << "output: \n" << func << std::endl; + auto out = GetStreamCnt(func); + EXPECT_EQ(Trim(out), Trim(R"ROC( +function func_C (_A, _B, _D) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 20) + { + D[i, j] = (1.00000000f + ((2.00000000f * A[i, j]) + (2.00000000f * B[i, j]))) + } + } +} +)ROC")); +} + +TEST(Tensor, IsDependOnStatement) { + Expr N(100); + + Placeholder X("X", {N}); + auto t = Compute( + {N}, [&](Var i) -> Expr { return X(i); }, "t"); + + ASSERT_TRUE(t->IsDependOnStatement("X")); + ASSERT_FALSE(t->IsDependOnStatement("XXX")); +} + +TEST(Tensor, Reshape) { + Context::Global().ResetNameId(); + Expr M(100); + Expr N(100); + Placeholder A("A", {M, N}); + + auto stages = CreateStages({A}); + + auto A1 = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages); + auto B = Compute( + A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; }, "B"); + + stages->InsertLazily(B); + + auto func = lang::Lower("fn", stages, {A, B}); + + ir::Module::Builder builder("some_modue", common::DefaultHostTarget()); + builder.AddFunction(func); + + backends::CodeGenC codegenc(common::DefaultHostTarget()); + codegenc.SetInlineBuiltinCodes(false); + auto source = codegenc.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + LOG(INFO) << "source:\n" << source; + + auto target_source = R"ROC( +#include +#include + +void fn(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_malloc((void*)(0), _B); + const float* A_reshape = ((const float*)(_A->memory)); + float* B = ((float*)(_B->memory)); + for (int32_t i = 0; i < 10; i += 1) { + for (int32_t j = 0; j < 10; j += 1) { + for (int32_t k = 0; k < 100; k += 1) { + B[((1000 * i) + ((100 * j) + k))] = (2.00000000f * A_reshape[((1000 * i) + ((100 * j) + k))]); + }; + }; + }; + cinn_buffer_free((void*)(0), _B); +} +)ROC"; + + ASSERT_EQ(Trim(target_source), Trim(source)); +} + +TEST(Tensor, ReshapeCopied) { + Context::Global().ResetNameId(); + Expr M(100); + Expr N(100); + Placeholder A("A", {M, N}); + + auto stages = CreateStages({A}); + + auto A1 = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages); + auto B = Compute( + A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; }, "B"); + + stages->InsertLazily(B); + + ir::Module::Builder builder("some_modue", common::DefaultHostTarget()); + auto func = lang::Lower("fn", stages, {A, B}, {}, {}, &builder); + + backends::CodeGenC 
codegenc(common::DefaultHostTarget());
+  codegenc.SetInlineBuiltinCodes(false);
+  auto source = codegenc.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
+  LOG(INFO) << "source:\n" << source;
+
+  auto target_source = R"ROC(
+#include
+#include
+
+void fn(void* _args, int32_t num_args)
+{
+  const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
+  cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
+  cinn_buffer_t* _A_copied_reshape = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 10, 10, 100 }, 32/*align*/);
+  cinn_buffer_malloc((void*)(0), _B);
+  cinn_buffer_malloc((void*)(0), _A_copied_reshape);
+  const float* A = ((const float*)(_A->memory));
+  float* A_copied = ((float*)(_A_copied_reshape->memory));
+  const float* A_copied_reshape = ((const float*)(_A_copied_reshape->memory));
+  float* B = ((float*)(_B->memory));
+  for (int32_t i = 0; i < 100; i += 1) {
+    for (int32_t j = 0; j < 100; j += 1) {
+      A_copied[((100 * i) + j)] = A[((100 * i) + j)];
+    };
+  };
+  for (int32_t i = 0; i < 10; i += 1) {
+    for (int32_t j = 0; j < 10; j += 1) {
+      for (int32_t k = 0; k < 100; k += 1) {
+        B[((1000 * i) + ((100 * j) + k))] = (2.00000000f * A_copied_reshape[((1000 * i) + ((100 * j) + k))]);
+      };
+    };
+  };
+  cinn_buffer_free((void*)(0), _A_copied_reshape);
+  cinn_buffer_free((void*)(0), _B);
+}
+)ROC";
+
+  ASSERT_EQ(Trim(target_source), Trim(source));
+}
+
+TEST(Tensor, reduce) {
+  Placeholder<float> A("A", {Expr(10)});
+  Var reduce_axis(Expr(10), "reduce_k");
+  {
+    auto C = Compute(
+        A->shape,
+        [=](const std::vector<Expr>& axis) { return lang::ReduceSum(A(reduce_axis) + 1.f, {reduce_axis}); },
+        "C");
+    ASSERT_TRUE(C->has_expression());
+    ASSERT_TRUE(C->is_reduce_sum());
+    ASSERT_FALSE(C->is_reduce_mul());
+  }
+
+  {
+    auto C = Compute(
+        A->shape,
+        [=](const std::vector<Expr>& axis) { return lang::ReduceMul(A(reduce_axis) + 1.f, {reduce_axis}); },
+        "C");
+    ASSERT_TRUE(C->has_expression());
+    ASSERT_TRUE(C->is_reduce_mul());
+    ASSERT_FALSE(C->is_reduce_sum());
+  }
+}
+
+} // namespace ir
+} // namespace cinn
diff --git a/paddle/cinn/lang/CMakeLists.txt b/paddle/cinn/lang/CMakeLists.txt
new file mode 100644
index 0000000000000..9a9c86a63e141
--- /dev/null
+++ b/paddle/cinn/lang/CMakeLists.txt
@@ -0,0 +1,17 @@
+core_gather_headers()
+
+gather_srcs(cinnapi_src SRCS
+    buffer.cc
+    compute.cc
+    placeholder.cc
+    lower.cc
+    builtin.cc
+    lower_impl.cc
+    packed_func.cc
+    )
+
+cc_test(test_compute SRCS compute_test.cc DEPS cinncore)
+cc_test(test_placeholder SRCS placeholder_test.cc DEPS cinncore)
+cc_test(test_lower SRCS lower_test.cc DEPS cinncore)
+cc_test(test_lower_impl SRCS lower_impl_test.cc DEPS cinncore)
+cc_test(test_cinn_packed_func SRCS packed_func_test.cc DEPS cinncore)
diff --git a/paddle/cinn/lang/README.md b/paddle/cinn/lang/README.md
new file mode 100644
index 0000000000000..ebbb2ca579f64
--- /dev/null
+++ b/paddle/cinn/lang/README.md
@@ -0,0 +1,93 @@
+# Design of CINN/DSL
+This module is a simple DSL defined in the CINN project.
+The DSL module aims to represent the overall computation in a hardware-independent way.
+
+## Concepts
+### Object
+All the mutable elements in CINN are `Object`s.
+### Shared
+`Shared` objects are reference-counted, self-contained containers, similar to `std::shared_ptr`.
+
+One can pass a `Shared` object by passing a pointer, and the consumer object should store it in a local `Shared` member variable.
+
+## Tensor
+
+The input or the temporary output node.
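+
+A minimal sketch of producing a tensor with `Compute` (schematic, in the
+style of the snippets below; `M` and `N` are integer `Expr` dimensions):
+
+```c++
+PlaceHolder A("A", {M, N});
+
+// B is a new Tensor defined element-wise over A.
+Tensor B = Compute({M, N}, [&](Var i, Var j) {
+  return A(i, j) + 1;
+});
+```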
+
+Every `Compute` outputs a `Tensor`, and a tensor can be sliced.
+
+
+
+### PlaceHolder
+
+The special tensor that represents an input slot.
+
+```c++
+PlaceHolder A("A", {M, N});
+PlaceHolder B("B", {M, N});
+```
+
+## Operation
+
+An `Operation` describes an operation on tensors, including
+
+- placeholder
+- compute
+- bound inference
+
+```c++
+Tensor C = Compute({M,N}/*output shape*/, [&](Var i, Var j) {
+  Var k;
+  return ReduceSum(A[i,k] * B[k,j], {k});
+});
+```
+
+### Bound inference
+
+The PlaceHolder should define a shape.
+
+```c++
+Var M(Int(32));
+Var N(Int(32));
+
+PlaceHolder A({M, N});
+
+Var i, j;
+Expr tmp = A[i][j] + 1; // i \in {0, M}; j \in {0, N}
+```
+
+To simplify the implementation, we use ISL to generate code for basic snippets.
+
+## Schedule
+
+The schedule will
+
+1. determine the order of computation by topologically sorting the computational graph composed of tensors, and
+2. transform the computations.
+
+### Order schedule
+
+1. Topologically sort the tensors.
+2. For each tensor, generate the code it needs.
+
+## Some examples
+A matrix multiplication:
+
+```c++
+// Declare some iterator variables.
+Var i, j, k;
+Placeholder A({M, K}), B({K, N});
+
+Tensor C = Compute({M, N}/*output shape*/,
+    [&](Var i, Var j) {
+      return ReduceSum(A(i,k) * B(k, j), {k});
+    }, "C");
+Tensor D = Compute({M, N}, [&](Var i, Var j) {
+  return Map(C(i,j) + 1);
+});
+
+Schedule s = CreateSchedule(C);
+auto func = Build(s, [A, B, C], target=target, name="matmul");
+
+func(a, b, c);
+```
diff --git a/paddle/cinn/lang/buffer.cc b/paddle/cinn/lang/buffer.cc
new file mode 100644
index 0000000000000..182d8c4b4c5a9
--- /dev/null
+++ b/paddle/cinn/lang/buffer.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/lang/buffer.h"
+
+#include "cinn/ir/buffer.h"
+
+namespace cinn {
+namespace lang {
+
+using ir::_Buffer_;
+
+Buffer::Buffer(Type type, const std::string& name) {
+  buffer_ = _Buffer_::Make();
+  buffer_->dtype = type;
+  buffer_->set_type(type_of<cinn_buffer_t*>());
+  buffer_->elem_offset = Expr(0);
+  if (!name.empty()) {
+    buffer_->name = name;
+  }
+  buffer_->target = common::DefaultHostTarget();
+}
+
+} // namespace lang
+} // namespace cinn
diff --git a/paddle/cinn/lang/buffer.h b/paddle/cinn/lang/buffer.h
new file mode 100644
index 0000000000000..bcb4f5a602e74
--- /dev/null
+++ b/paddle/cinn/lang/buffer.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+#include "cinn/ir/buffer.h"
+
+namespace cinn {
+namespace lang {
+
+/**
+ * This is a DSL wrapper for ir::Buffer.
+ */
+class Buffer {
+ public:
+  explicit Buffer(Type type, const std::string& name = "");
+  explicit Buffer(const ir::Buffer& x) : buffer_(x) {}
+
+  ir::_Buffer_* operator->() { return buffer_.As<ir::_Buffer_>(); }
+  const ir::_Buffer_* operator->() const { return buffer_.As<ir::_Buffer_>(); }
+
+  ir::_Buffer_* self() { return buffer_.As<ir::_Buffer_>(); }
+
+  ir::Buffer buffer() const { return buffer_; }
+
+ private:
+  ir::Buffer buffer_;
+};
+
+} // namespace lang
+} // namespace cinn
diff --git a/paddle/cinn/lang/builtin.cc b/paddle/cinn/lang/builtin.cc
new file mode 100644
index 0000000000000..266f704a76576
--- /dev/null
+++ b/paddle/cinn/lang/builtin.cc
@@ -0,0 +1,262 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/lang/builtin.h"
+
+#include
+#include
+#include
+
+#include "cinn/cinn.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/ir.h"
+#include "cinn/lang/buffer.h"
+
+namespace cinn {
+namespace lang {
+
+using cinn::common::bfloat16;
+using cinn::common::float16;
+
+Expr logic_and(const std::vector<Expr>& conds) {
+  CHECK(!conds.empty());
+  // A single condition needs no conjunction; guard before touching conds[1].
+  if (conds.size() == 1) return conds[0];
+  auto start = ir::And::Make(conds[0], conds[1]);
+  for (int i = 2; i < conds.size(); i++) {
+    start = ir::And::Make(start, conds[i]);
+  }
+  return start;
+}
+
+Expr logic_or(const std::vector<Expr>& conds) {
+  CHECK(!conds.empty());
+  if (conds.size() == 1) return conds[0];
+  auto start = ir::Or::Make(conds[0], conds[1]);
+  for (int i = 2; i < conds.size(); i++) {
+    start = ir::Or::Make(start, conds[i]);
+  }
+  return start;
+}
+
+//!
extern call op +#define EXTERN_CALL_IMP(name__, target__) \ + Expr name__(Expr e) { return ir::Call::Make(e->type(), #target__, {e}, {}, ir::CallType::Extern); } + +#define EXTERN_CALL_IMP_NO_VEC(name__, target__) \ + Expr name__(Expr e) { \ + return ir::Call::Make( \ + e->type(), #target__, {e}, {}, ir::CallType::Extern, ir::FunctionRef(), 0, {{"vectorizable", false}}); \ + } + +EXTERN_CALL_IMP(Exp, exp); +EXTERN_CALL_IMP_NO_VEC(Erf, erf); +EXTERN_CALL_IMP(Sqrt, sqrt); +EXTERN_CALL_IMP(Rsqrt, rsqrt); +EXTERN_CALL_IMP(Log, log); +EXTERN_CALL_IMP(Log2, log2); +EXTERN_CALL_IMP(Log10, log10); +EXTERN_CALL_IMP(Floor, floor); +EXTERN_CALL_IMP(Ceil, ceil); +EXTERN_CALL_IMP(Round, round); +EXTERN_CALL_IMP(Trunc, trunc); +EXTERN_CALL_IMP(Cos, cos); +EXTERN_CALL_IMP(Sin, sin); +EXTERN_CALL_IMP(Cosh, cosh); +EXTERN_CALL_IMP(Tan, tan); +EXTERN_CALL_IMP(Tanh, tanh); +EXTERN_CALL_IMP(Sinh, sinh); +EXTERN_CALL_IMP_NO_VEC(Acos, acos); +EXTERN_CALL_IMP_NO_VEC(Acosh, acosh); +EXTERN_CALL_IMP_NO_VEC(Asin, asin); +EXTERN_CALL_IMP_NO_VEC(Asinh, asinh); +EXTERN_CALL_IMP_NO_VEC(Atan, atan); +EXTERN_CALL_IMP_NO_VEC(Atanh, atanh); +EXTERN_CALL_IMP(Cbrt, cbrt); +EXTERN_CALL_IMP(Clz, clz); +EXTERN_CALL_IMP(Popc, popc); + +#undef EXTERN_CALL_IMP +#undef EXTERN_CALL_IMP_NO_VEC + +#define EXTERN_BINARY_CALL_IMP(name__, target__) \ + Expr name__(Expr a, Expr b) { \ + CHECK_EQ(a.type(), b.type()) << #name__ << "'s inputs type not equal, where a:" << a.type() \ + << " but b:" << b.type(); \ + return ir::Call::Make(a->type(), #target__, {a, b}, {}, ir::CallType::Extern); \ + } + +EXTERN_BINARY_CALL_IMP(Remainder, mod) +EXTERN_BINARY_CALL_IMP(LogicalRightShift, logical_right_shift) +EXTERN_BINARY_CALL_IMP(Pow, pow) +EXTERN_BINARY_CALL_IMP(Mod, mod) + +#undef EXTERN_BINARY_CALL_IMP + +Expr Zero(const Type& type) { return ir::Zero(type); } + +Expr One(const Type& type) { return ir::One(type); } + +Expr FloorDivide(Expr a, Expr b) { + CHECK_EQ(a.type(), b.type()) << "FloorDivide's inputs type not equal, where a:" << a.type() << " but b:" << b.type(); + if (a.type().is_float()) { + return Floor(a / b); + } else if (a.type().is_uint()) { + return a / b; + } else { + auto div = a / b; + auto mod = a % b; + auto ret = ir::Select::Make( + ir::EQ::Make(mod, common::make_const(a.type(), 0)), div, div - common::make_const(a.type(), 1)); + return ir::Select::Make((a > 0 && b > 0) || (a < 0 && b < 0), div, ret); + } +} + +Expr min_value(const Type& type) { + CHECK_EQ(type.lanes(), 1); +#define FOR_CASE(type__) \ + if (type == type_of()) { \ + return Expr(static_cast(std::numeric_limits::lowest())); \ + } + FOR_CASE(int8_t) + FOR_CASE(int16_t) + FOR_CASE(int32_t) + FOR_CASE(int64_t) + FOR_CASE(uint8_t) + FOR_CASE(uint16_t) + FOR_CASE(uint32_t) + FOR_CASE(uint64_t) + FOR_CASE(bfloat16) + FOR_CASE(float16) + FOR_CASE(float) + FOR_CASE(double) +#undef FOR_CASE + return Expr(); +} + +Expr max_value(const Type& type) { + CHECK_EQ(type.lanes(), 1); + +#define FOR_CASE(type__) \ + if (type == type_of()) { \ + return Expr(static_cast(std::numeric_limits::max())); \ + } + FOR_CASE(int8_t) + FOR_CASE(int16_t) + FOR_CASE(int32_t) + FOR_CASE(int64_t) + FOR_CASE(uint8_t) + FOR_CASE(uint16_t) + FOR_CASE(uint32_t) + FOR_CASE(uint64_t) + FOR_CASE(bfloat16) + FOR_CASE(float16) + FOR_CASE(float) + FOR_CASE(double) +#undef FOR_CASE + + CINN_NOT_IMPLEMENTED + return Expr(); +} + +Expr Epsilon(const Type& type) { + CHECK_EQ(type.lanes(), 1); + +#define FOR_CASE(type__) \ + if (type == type_of()) { \ + return 
Expr(static_cast(std::numeric_limits::epsilon())); \ + } + FOR_CASE(int8_t) + FOR_CASE(int16_t) + FOR_CASE(int32_t) + FOR_CASE(int64_t) + FOR_CASE(uint8_t) + FOR_CASE(uint16_t) + FOR_CASE(uint32_t) + FOR_CASE(uint64_t) + FOR_CASE(bfloat16) + FOR_CASE(float16) + FOR_CASE(float) + FOR_CASE(double) +#undef FOR_CASE + + CINN_NOT_IMPLEMENTED + return Expr(); +} + +Expr Abs(Expr e) { + Type type = e->type(); + Type bool_type = Bool(type.lanes()); + if (type.is_uint()) { + return e; + } else if (type.is_int() || type.is_float()) { + auto node = e.As(); + if (node) { + return make_const(type, std::abs(node->value)); + } + return ir::Select::Make(e > Zero(e->type()), e, -e); + } else { + LOG(FATAL) << "Abs Not support data type " << type; + } + return e; +} + +Expr IsNan(Expr e) { + Type type = e->type(); + if (type.is_int() || type.is_uint()) { + return common::make_bool(false, type.lanes()); + } else if (type.is_float()) { + auto* node = e.As(); + if (node) { + return common::make_bool(std::isnan(node->value), type.lanes()); + } + return CallExtern("isnan", {e}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << type << "is not supported for isnan op."; + return e; + } +} + +Expr Infinity(const Type& type) { + CHECK_EQ(type.lanes(), 1U); + if (type.is_float()) { + if (type.bits() == 64) { + return make_const(type, std::numeric_limits::infinity()); + } else if (type.bits() == 32) { + return make_const(type, std::numeric_limits::infinity()); + } else if (type.bits() == 16) { + return make_const(type, std::numeric_limits::infinity()); + } + } + LOG(FATAL) << "Cannot decide infinity for type " << type; + return Expr(); +} + +Expr IsInf(Expr e) { + Type type = e->type(); + if (type.is_int() || type.is_uint()) { + return common::make_bool(false, type.lanes()); + } else if (type.is_float()) { + auto* node = e.As(); + if (node) { + return common::make_bool(std::isinf(node->value), type.lanes()); + } + return CallExtern("isinf", {e}, {{"vectorizable", false}}); + } else { + LOG(FATAL) << type << "is not supported for isinf op."; + return e; + } +} + +Expr IsFinite(Expr e) { return !IsInf(e) && !IsNan(e); } + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/builtin.h b/paddle/cinn/lang/builtin.h new file mode 100644 index 0000000000000..763461b697bc2 --- /dev/null +++ b/paddle/cinn/lang/builtin.h @@ -0,0 +1,173 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "cinn/common/ir_util.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_operators.h" + +namespace cinn { +namespace lang { + +//! Get the ALL of the conditions. +Expr logic_and(const std::vector& conds); +Expr logic_or(const std::vector& conds); + +Expr Zero(const Type& type); +Expr One(const Type& type); +Expr min_value(const Type& type); +Expr max_value(const Type& type); +Expr Epsilon(const Type& type); + +//! 
extern call op +#define EXTERN_CALL_DCL(name__) Expr name__(Expr e); + +EXTERN_CALL_DCL(Exp); +EXTERN_CALL_DCL(Erf); +EXTERN_CALL_DCL(Sqrt); +EXTERN_CALL_DCL(Rsqrt); +EXTERN_CALL_DCL(Log); +EXTERN_CALL_DCL(Log2); +EXTERN_CALL_DCL(Log10); +EXTERN_CALL_DCL(Floor); +EXTERN_CALL_DCL(Ceil); +EXTERN_CALL_DCL(Round); +EXTERN_CALL_DCL(Trunc); +EXTERN_CALL_DCL(Cos); +EXTERN_CALL_DCL(Cosh); +EXTERN_CALL_DCL(Tan); +EXTERN_CALL_DCL(Sin); +EXTERN_CALL_DCL(Sinh); +EXTERN_CALL_DCL(Acos); +EXTERN_CALL_DCL(Acosh); +EXTERN_CALL_DCL(Asin); +EXTERN_CALL_DCL(Asinh); +EXTERN_CALL_DCL(Atan); +EXTERN_CALL_DCL(Atanh); +EXTERN_CALL_DCL(Tanh); +EXTERN_CALL_DCL(Cbrt); +EXTERN_CALL_DCL(Clz); +EXTERN_CALL_DCL(Popc); + +#undef EXTERN_CALL_DCL + +//! extern call binary op +#define EXTERN_BINARY_CALL_DCL(name__) Expr name__(Expr a, Expr b); + +EXTERN_BINARY_CALL_DCL(FloorDivide); +EXTERN_BINARY_CALL_DCL(Remainder); +EXTERN_BINARY_CALL_DCL(Mod); +EXTERN_BINARY_CALL_DCL(LogicalRightShift); +EXTERN_BINARY_CALL_DCL(Pow); + +#undef EXTERN_BINARY_CALL_DCL + +inline Expr Sigmoid(Expr e) { + auto one = One(e->type()); + return one / (one + Exp(-e)); +} + +inline Expr Sign(Expr e) { + auto zero = Zero(e->type()); + auto one = One(e->type()); + auto neg_one = ir::Cast::Make(e->type(), Expr(-1)); + auto ret0 = ir::Select::Make(ir::EQ::Make(e, zero), zero, e); + auto ret1 = ir::Select::Make(e > zero, one, ret0); + auto ret2 = ir::Select::Make(e < zero, neg_one, ret1); + return ret2; +} + +Expr Abs(Expr e); + +inline Expr Negative(Expr e) { return -e; } +inline Expr Identity(Expr e) { return e; } +inline Expr LogicalNot(Expr e) { return !e; } +inline Expr BitwiseNot(Expr e) { return ~e; } +inline Expr BitwiseAnd(Expr a, Expr b) { return a & b; } +inline Expr BitwiseOr(Expr a, Expr b) { return a | b; } +inline Expr BitwiseXor(Expr a, Expr b) { return a ^ b; } +inline Expr LeftShift(Expr a, Expr b) { return a << b; } +inline Expr RightShift(Expr a, Expr b) { return a >> b; } + +inline Expr Relu(Expr e, double threshold = 0.0) { + return ir::Max::Make(e, ir::Cast::Make(e->type(), Expr(threshold))); +} + +inline Expr Relu6(Expr e, double threshold = 0.0) { + return ir::Min::Make(ir::Max::Make(e, ir::Cast::Make(e->type(), Expr(threshold))), + ir::Cast::Make(e->type(), Expr(6.0))); +} + +inline Expr LeakyRelu(Expr e, double alpha) { + auto zero = Zero(e->type()); + return ir::Select::Make(e > zero, e, e * ir::Cast::Make(e->type(), Expr(alpha))); +} + +inline Expr LeakyRelu(Expr e, Expr alpha) { + auto zero = Zero(e->type()); + return ir::Select::Make(e > zero, e, e * alpha); +} + +inline Expr ReduceSum(Expr e, const std::vector& reduce_axis, Expr initial = Expr()) { + if (!initial.defined()) { + initial = Zero(e->type()); + } + return ir::Reduce::Make(ir::Reduce::kSum, initial, e, reduce_axis); +} + +inline Expr ReduceMul(Expr e, const std::vector& reduce_axis, Expr initial = Expr()) { + if (!initial.defined()) { + initial = One(e->type()); + } + return ir::Reduce::Make(ir::Reduce::kMul, initial, e, reduce_axis); +} + +inline Expr ReduceMax(Expr e, const std::vector& reduce_axis, Expr initial = Expr()) { + if (!initial.defined()) { + initial = min_value(e.type()); + } + return ir::Reduce::Make(ir::Reduce::kMax, initial, e, reduce_axis); +} +inline Expr ReduceMin(Expr e, const std::vector& reduce_axis, Expr initial = Expr()) { + if (!initial.defined()) { + initial = max_value(e.type()); + } + return ir::Reduce::Make(ir::Reduce::kMin, initial, e, reduce_axis); +} +inline Expr ReduceAll(Expr e, const std::vector& reduce_axis, Expr initial = 
Expr()) { + if (!initial.defined()) { + initial = Expr(true); + } + return ir::Reduce::Make(ir::Reduce::kAll, initial, e, reduce_axis); +} +inline Expr ReduceAny(Expr e, const std::vector& reduce_axis, Expr initial = Expr()) { + if (!initial.defined()) { + initial = Expr(false); + } + return ir::Reduce::Make(ir::Reduce::kAny, initial, e, reduce_axis); +} + +Expr IsNan(Expr e); + +Expr Infinity(const Type& type); + +Expr IsInf(Expr e); + +Expr IsFinite(Expr e); + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/compute.cc b/paddle/cinn/lang/compute.cc new file mode 100644 index 0000000000000..ac2e83ede44cf --- /dev/null +++ b/paddle/cinn/lang/compute.cc @@ -0,0 +1,229 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/lang/compute.h" + +#include "cinn/backends/extern_func_protos.h" +#include "cinn/common/common.h" +#include "cinn/ir/operation.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/poly/dim.h" +#include "cinn/poly/domain.h" +#include "cinn/poly/stage.h" +#include "cinn/runtime/use_extern_funcs.h" + +namespace cinn { +namespace lang { + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + // CHECK_EQ(axis.size(), 0); + return fn(); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 1); + return fn(axis[0]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 2); + return fn(axis[0], axis[1]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 3); + return fn(axis[0], axis[1], axis[2]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 4); + return fn(axis[0], axis[1], axis[2], axis[3]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, + [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 5); + return fn(axis[0], axis[1], axis[2], axis[3], axis[4]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape) { + return Compute( + domain, 
+ [fn](const std::vector &axis) -> Expr { + CHECK_EQ(axis.size(), 6); + return fn(axis[0], axis[1], axis[2], axis[3], axis[4], axis[5]); + }, + name, + shape); +} + +ir::Tensor Compute(const std::vector &domain, + std::function &)> fn, + const std::string &name, + const std::vector &shape) { + auto axises = common::GenDefaultAxis(domain.size()); + std::vector _axis; + for (auto &x : axises) _axis.push_back(x); + Expr fn_body = fn(_axis); + + std::vector reduce_axis; + if (fn_body.defined() && fn_body.As()) { + auto &fn_reduce_axis = fn_body.As()->reduce_axis; + reduce_axis.insert(std::begin(reduce_axis), fn_reduce_axis.begin(), fn_reduce_axis.end()); + } + + // When the fn_body is a CallExtern, a tensor will return directly. + if (fn_body.as_tensor()) { + return fn_body.as_tensor_ref(); + } + + // shape is the buffer's shape. + std::vector domain_without_reduce_axis; + std::vector shape_simplified; + + // construct the shape. + for (auto dim : domain) { + auto copied = dim; + optim::Simplify(&copied); + domain_without_reduce_axis.push_back(copied); + } + + for (auto dim : shape) { + auto copied = dim; + optim::Simplify(&copied); + shape_simplified.push_back(copied); + } + + auto real_shape = shape_simplified.empty() ? domain_without_reduce_axis : shape_simplified; + + // The body returns void, that means no buffer is needed. + if (fn_body.type() == Void()) real_shape.clear(); + + auto unique_name = name.empty() ? Context::Global().NewName("tensor") : name; + + // check reduce_axis not include the reserved axis name + for (auto &ra : reduce_axis) { + CHECK(!common::IsAxisNameReserved(ra->name)) << "reduce axis [" << ra->name << "]'s name is reserved"; + } + + VLOG(3) << "tensor " << name << "'s domain is : " << domain_without_reduce_axis; + + auto op = ir::ComputeOp::Make(unique_name, fn, real_shape, domain_without_reduce_axis, reduce_axis); + auto tensor = ir::Tensor(unique_name, fn_body.type(), real_shape, domain_without_reduce_axis, op, reduce_axis); + return tensor; +} + +std::vector CallLowered(const std::string &func_name, + const std::vector &args, + const std::vector &return_types) { + auto call = ir::Call::Make(Void(), func_name, args, {}, ir::CallType::CINN, ir::FunctionRef(), 0); + std::vector new_tensors; + for (int i = 0; i < return_types.size(); i++) { + auto &return_type = return_types[i]; + auto call_op = ir::CallOp::Make(func_name, call); + auto new_tensor = ir::Tensor(return_type.name, return_type.type, return_type.dims, {Expr(1)}, call_op); + // Append write tensors in the tail. + call.As()->write_args.push_back(new_tensor); + new_tensor->set_type(return_type.type); + new_tensor->WithBuffer(); + new_tensors.push_back(new_tensor); + } + + return new_tensors; +} + +Expr CallExtern(const std::string &func_name, + const std::vector &args, + const std::map &attrs) { + auto *proto = backends::ExternFunctionProtoRegistry::Global().Lookup(func_name); + CHECK(proto) << "No extern function prototype " << func_name << " found\n" + << "existing records are:\n" + << backends::ExternFunctionProtoRegistry::Global().debug_string(); + + auto call = ir::Call::Make(proto->ret_type, func_name, args, {}, ir::CallType::Extern, ir::FunctionRef(), 0, attrs); + std::vector mutable_args; + // Call a function with multiple outputs. 
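+  // Two cases follow, keyed on the prototype's return type:
+  //  1. void return: every result is written through a mutable tensor
+  //     argument; each output is modeled below as a tuple element, i.e. a
+  //     Tensor whose CallOp records the slot (value_slot) it reads from.
+  //  2. POD return: the Call expression itself carries the value and is
+  //     returned unchanged at the end of this function.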
+ if (proto->ret_type.is_void()) { + for (int i = 0; i < proto->mutable_arg_types.size(); i++) { + auto shape = proto->shape_inference(args, i); + auto op = ir::CallOp::Make(func_name, call); + op->as()->value_slot = i; + op->as()->is_tuple_get = true; + auto name = cinn::UniqName("tuple_" + func_name + "_out" + std::to_string(i) + "_"); + auto ret = ir::Tensor(name, proto->mutable_arg_types[i], shape, shape, op, {}); + mutable_args.push_back(ret); + } + call.As()->write_args = mutable_args; + } + return call; +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/compute.h b/paddle/cinn/lang/compute.h new file mode 100755 index 0000000000000..230a2037c80a0 --- /dev/null +++ b/paddle/cinn/lang/compute.h @@ -0,0 +1,132 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/schedule.h" + +namespace cinn { +namespace lang { + +using compute_handler_t = std::function &)>; +using attr_t = absl::variant; + +//! Compute methods for one to five Vars as arguments. +// @{ +// The shape are constant integers. +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); + +ir::Tensor Compute(const std::vector &domain, + std::function fn, + const std::string &name, + const std::vector &shape = {}); + +ir::Tensor Compute(const std::vector &domain, + compute_handler_t fn, + const std::string &name, + const std::vector &shape = {}); +// @} + +struct ReturnType { + Type type; + std::vector dims; + std::string name; +}; + +/** + * \brief Call a lowered function and return one or more tensors as result. + * + * A lowered function is generated by lang::Lower method. + * + * TODO(Superjomn) Add a registry (symbol table?) to make return result inference automatically. + * + * @param func_name The name of the function to call. + * @param args The readonly arguments(while the mutable tensors are return result). + * @param return_types The types of the return values. + * @return Return one or more tensors as result. 
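+ *
+ * A minimal usage sketch (the function name and shapes here are illustrative only):
+ * \code
+ * std::vector<ReturnType> return_types = {{Float(32), {M, Expr(20)}, "C"}};
+ * auto outs = CallLowered("lowered_fn0", {Expr(x), Expr(y)}, return_types);
+ * \endcode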
+ */
+std::vector<ir::Tensor> CallLowered(const std::string &func_name,
+                                    const std::vector<Expr> &args,
+                                    const std::vector<ReturnType> &return_types);
+
+/**
+ * \brief Call an external function and get some tensors as result.
+ *
+ * There are two kinds of extern functions, distinguished by the return type.
+ *
+ * 1. Void: the results are written through one or more mutable tensors in the argument list.
+ * \code
+ * Tensor tuple = Compute({M}, []() { return CallExtern("mkl_gemm", {X, W}); });
+ * \endcode
+ *
+ * which will generate something like
+ *
+ * \code
+ * for (i) {
+ *   gemm_mkl(X[i], gemm_out[i])
+ * }
+ * \endcode
+ *
+ * To support returning multiple values at a time, we introduce the tuple concept: a tuple is a Tensor whose CallOp
+ * is marked with a value_offset (from 0 to num_returns-1).
+ *
+ * 2. POD value: an expression is returned directly, and it can be inline-expanded in the following computations.
+ * \code
+ * Tensor tanh_out = Compute({M}, [](Var i) { return CallExtern("tanh", X(i)); });
+ * \endcode
+ *
+ * @param func_name The name of the function to call.
+ * @param args The read-only arguments (the mutable result tensors, if any, are appended to the call's write_args).
+ * @param attrs The read-only attributes.
+ */
+Expr CallExtern(const std::string &func_name,
+                const std::vector<Expr> &args,
+                const std::map<std::string, attr_t> &attrs = {});
+
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/compute_test.cc b/paddle/cinn/lang/compute_test.cc
new file mode 100644
index 0000000000000..cca239df92fbc
--- /dev/null
+++ b/paddle/cinn/lang/compute_test.cc
@@ -0,0 +1,39 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/lang/compute.h"
+
+#include <gtest/gtest.h>
+
+#include "cinn/cinn.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/lang/buffer.h"
+#include "cinn/lang/placeholder.h"
+
+namespace cinn {
+namespace lang {
+
+TEST(Call, basic) {
+  Expr M(100);
+
+  Placeholder<float> x("x", {M, Expr(10)});
+  Placeholder<float> y("y", {M, Expr(10)});
+
+  std::vector<ReturnType> return_types({{Float(32), std::vector<Expr>{{M, Expr(20)}}, "C"}});
+  auto tensors = CallLowered("lowered_fun0", {Expr(x), Expr(y)}, return_types);
+}
+
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/lower.cc b/paddle/cinn/lang/lower.cc
new file mode 100755
index 0000000000000..5781c69e3b853
--- /dev/null
+++ b/paddle/cinn/lang/lower.cc
@@ -0,0 +1,302 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
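+
+// lower.cc implements the lang::Lower / lang::LowerVec entry points declared
+// in lower.h: it feeds the stages and arguments into detail::LowerImpl, then
+// attaches temporary buffers and the device API to each resulting
+// ir::LoweredFunc.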
+ +#include "cinn/lang/lower.h" + +#include +#include +#include +#include +#include +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/lower_impl.h" +#include "cinn/optim/optimize.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace lang { + +using ir::Tensor; +using poly::Stage; + +std::vector GetArgs(const Expr& func_body, const std::vector& input_output_nodes) { + std::vector res; + std::map> name2loads; + std::map> name2stores; + auto load_or_store_nodes = ir::CollectIRNodesWithoutTensor( + func_body, [&](const Expr* x) { return x->As() || x->As(); }); + + for (auto&& e : load_or_store_nodes) { + if (e.As()) { + auto&& tensor_name = e.As()->tensor.as_tensor()->name; + name2loads[tensor_name].insert(e.As()); + } else { // Store node + auto&& tensor_name = e.As()->tensor.as_tensor()->name; + name2stores[tensor_name].insert(e.As()); + } + } + + for (auto&& node_name : input_output_nodes) { + auto load_it = name2loads.find(node_name); + auto store_it = name2stores.find(node_name); + // if a node is ir::Load and also ir::Store, then process it as a ir::Store in priority. + if (store_it != name2stores.end()) { // + for (auto&& node : store_it->second) { + const auto* tensor = node->tensor.as_tensor(); + if (tensor->buffer.defined()) { + res.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + break; + } + } + } else if (load_it != name2loads.end()) { + for (auto&& node : load_it->second) { + const auto* tensor = node->tensor.as_tensor(); + if (tensor->buffer.defined()) { + res.emplace_back(tensor->buffer, ir::Argument::IO::kInput); + break; + } + } + } + } + + if (VLOG_IS_ON(3)) { + for (auto& i : input_output_nodes) VLOG(3) << "In input_output_nodes, arg has : " << i; + for (auto& i : res) VLOG(3) << "In res, arg has : " << i.name(); + } + return res; +} + +//! Collect the temporary tensors from a computational graph. +std::vector GetTempBuffers(const std::vector& tensor_args, + const poly::StageMap& stage_map, + Expr body) { + std::unordered_set tensor_arg_names; + std::unordered_set buffer_arg_names; + for (auto& tensor : tensor_args) { + tensor_arg_names.insert(tensor->name); + if (tensor->buffer.defined()) { + buffer_arg_names.insert(tensor->buffer->name); + } + } + std::map name_to_buffer; // used to avoid duplication. 
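+
+  // A tensor counts as a temporary here when it owns a buffer that is neither
+  // a function argument nor produced by an inlined stage; buffers whose name
+  // ends with "temp_buffer" are always kept. When several tensors share a
+  // buffer name, the smallest buffer (by numel) wins, see the update below.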
+ + auto all_temp_tensors = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + return x->as_tensor() && x->as_tensor()->buffer.defined() && + (!stage_map->Lookup(x->as_tensor()->name) || !stage_map[x->as_tensor()]->inlined()) && + ((!buffer_arg_names.count(x->as_tensor()->buffer->name) && !tensor_arg_names.count(x->as_tensor()->name)) || + utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer")); + }); + for (auto& e : all_temp_tensors) { + auto buffer_name = e.as_tensor()->buffer->name; + if (!name_to_buffer.count(buffer_name)) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } else { + if (e.as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } + } + } + // visit the ir body and update the map of name_to_buffer + auto update_map = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + if (x->as_tensor() && x->as_tensor()->buffer.defined()) { + auto buffer_name = x->as_tensor()->buffer->name; + if (name_to_buffer.count(buffer_name) && x->as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + name_to_buffer[buffer_name] = x->as_tensor()->buffer; + } + } + return x->as_tensor() && x->as_tensor()->buffer.defined(); + }); + + std::vector temp_buffers; + for (auto& i : name_to_buffer) temp_buffers.push_back(i.second); + return temp_buffers; +} + +//! Collect the temporary tensors from a computational graph. +std::vector GetTempBuffers(const std::vector& args, Expr body) { + std::unordered_set buffer_arg_names; + for (auto& a : args) { + if (a.is_buffer()) { + buffer_arg_names.insert(a.name()); + } + } + std::map name_to_buffer; // used to avoid duplication. + + auto all_temp_tensors = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + return x->as_tensor() && x->as_tensor()->buffer.defined() && + (!buffer_arg_names.count(x->as_tensor()->buffer->name) || + utils::Endswith(x->as_tensor()->buffer->name, "temp_buffer")); + }); + for (auto& e : all_temp_tensors) { + auto buffer_name = e.as_tensor()->buffer->name; + if (!name_to_buffer.count(buffer_name)) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } else { + if (e.as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + name_to_buffer[buffer_name] = e.as_tensor()->buffer; + } + } + } + // visit the ir body and update the map of name_to_buffer + auto update_map = ir::CollectIRNodesWithoutTensor(body, [&](const Expr* x) { + if (x->as_tensor() && x->as_tensor()->buffer.defined()) { + auto buffer_name = x->as_tensor()->buffer->name; + if (name_to_buffer.count(buffer_name) && x->as_tensor()->buffer->numel() < name_to_buffer[buffer_name]->numel()) { + name_to_buffer[buffer_name] = x->as_tensor()->buffer; + } + } + return x->as_tensor() && x->as_tensor()->buffer.defined(); + }); + + std::vector temp_buffers; + for (auto& i : name_to_buffer) temp_buffers.push_back(i.second); + return temp_buffers; +} + +std::set CollectTempTensorsFromCtrlDepends(StageMap stages, const std::vector& tensor_args) { + std::set res; + for (auto& stage : stages) { + res.emplace(ir::Tensor(stage.second->tensor())); + res.insert(stage.second->ctrl_depends().begin(), stage.second->ctrl_depends().end()); + } + for (auto& t : tensor_args) { + if (res.count(t)) res.erase(t); + } + return res; +} + +void InitReduceTensor(StageMap stages, const Tensor& tensor, const Target& target) { + if (tensor->is_reduce_tensor() && !tensor->IsReduceInited(stages)) { + tensor->InitReduction(stages, target); + } + auto 
uninited_reduce_tensors = ir::CollectIRNodes(tensor->body(), [&](const Expr* x) {
+    return x && x->defined() && x->as_tensor() && x->as_tensor()->is_reduce_tensor() &&
+           !x->as_tensor()->IsReduceInited(stages);
+  });
+  for (auto& t : uninited_reduce_tensors) {
+    VLOG(3) << "Init reduce tensor: " << t.as_tensor()->name;
+    t.as_tensor()->InitReduction(stages, target);
+  }
+}
+
+ir::LoweredFunc Lower(const std::string& name,
+                      StageMap stages,
+                      const std::vector<Tensor>& tensor_args,
+                      const std::vector<Var>& scalar_args,
+                      const std::vector<Tensor>& temp_tensors,
+                      Module::Builder* b,
+                      const Target& target,
+                      bool support_ir_schedule) {
+  // Init the reduce tensors first before any process.
+  for (auto& t : tensor_args) InitReduceTensor(stages, t, target);
+  for (auto& t : temp_tensors) InitReduceTensor(stages, t, target);
+  // Merge the ctrl_deps with the given temp_tensors and get a new temp_tensors.
+  auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args);
+  ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end());
+  auto lower_impl_instance = detail::LowerImpl(name,
+                                               stages,
+                                               tensor_args,
+                                               scalar_args,
+                                               std::vector<ir::Tensor>(ctrl_deps.begin(), ctrl_deps.end()),
+                                               target,
+                                               support_ir_schedule);
+  auto result = lower_impl_instance();
+  std::vector<ir::LoweredFunc> return_value;
+  for (auto& res : result) {
+    auto temp_buffers = GetTempBuffers(tensor_args, stages, res->body);
+    if (b) {
+      for (auto& temp_buffer : temp_buffers) {
+        b->AddBuffer(temp_buffer);
+      }
+    }
+    {  // set function device_api
+      for (auto& stage : stages) {
+        if (stage.second->IfCudaBind()) {
+          res->device_api = ir::DeviceAPI::GPU;
+          break;
+        }
+      }
+      if (target == common::DefaultNVGPUTarget()) {
+        res->device_api = ir::DeviceAPI::GPU;
+      }
+    }
+    if (b) {
+      b->AddFunction(res);
+    }
+    res->temp_bufs = temp_buffers;
+    return_value.push_back(res);
+  }
+  return return_value[0];
+}
+
+std::vector<ir::LoweredFunc> LowerVec(const std::string& name,
+                                      StageMap stages,
+                                      const std::vector<Tensor>& tensor_args,
+                                      const std::vector<Var>& scalar_args,
+                                      const std::vector<Tensor>& temp_tensors,
+                                      Module::Builder* b,
+                                      const Target& target,
+                                      bool support_ir_schedule) {
+  // Init the reduce tensors first before any process.
+  for (auto& t : tensor_args) InitReduceTensor(stages, t, target);
+  for (auto& t : temp_tensors) InitReduceTensor(stages, t, target);
+  // Merge the ctrl_deps with the given temp_tensors and get a new temp_tensors.
+  auto ctrl_deps = CollectTempTensorsFromCtrlDepends(stages, tensor_args);
+  ctrl_deps.insert(temp_tensors.begin(), temp_tensors.end());
+  auto lower_impl_instance = detail::LowerImpl(name,
+                                               stages,
+                                               tensor_args,
+                                               scalar_args,
+                                               std::vector<ir::Tensor>(ctrl_deps.begin(), ctrl_deps.end()),
+                                               target,
+                                               support_ir_schedule);
+  // Returns a vector of ir::LoweredFunc.
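+  // Each schedule group lowers to one function: the first keeps `name`, the
+  // following ones get the suffixes "_1", "_2", ... (see the declaration in
+  // lower.h). Lower() above is the single-function variant and returns only
+  // the first element.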
+  auto result = lower_impl_instance();
+  std::vector<ir::LoweredFunc> return_value;
+  for (auto& res : result) {
+    auto temp_buffers = GetTempBuffers(tensor_args, stages, res->body);
+    if (b) {
+      for (auto& temp_buffer : temp_buffers) {
+        b->AddBuffer(temp_buffer);
+      }
+    }
+
+    {  // set function device_api
+      for (auto& stage : stages) {
+        if (stage.second->IfCudaBind()) {
+          res->device_api = ir::DeviceAPI::GPU;
+          break;
+        }
+      }
+
+      if (target == common::DefaultNVGPUTarget()) {
+        res->device_api = ir::DeviceAPI::GPU;
+      }
+    }
+    if (b) {
+      b->AddFunction(res);
+    }
+
+    res->temp_bufs = temp_buffers;
+
+    return_value.push_back(res);
+  }
+  return return_value;
+}
+
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/lower.h b/paddle/cinn/lang/lower.h
new file mode 100644
index 0000000000000..d20adad843174
--- /dev/null
+++ b/paddle/cinn/lang/lower.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * Lower lowers the statements to LoweredFuncs.
+ */
+
+#pragma once
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/lowered_func.h"
+#include "cinn/ir/module.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/lang/packed_func.h"
+#include "cinn/poly/schedule.h"
+
+namespace cinn {
+namespace lang {
+using ir::Tensor;
+using poly::StageMap;
+
+/**
+ * \brief Lower the computation of \p tensor_args and \p scalar_args to a LoweredFunc.
+ * @param name The name of the function.
+ * @param tensor_args The tensor arguments, where the computation logic resides.
+ * @param scalar_args The scalar arguments, which usually indicate dimensions.
+ * @param temp_tensors The temporary tensors (buffers) used in the body.
+ * @param b The module this function belongs to.
+ * @return A LoweredFunc whose name is \p name; the argument list is the concatenation of \p scalar_args and \p
+ * tensor_args.
+ */
+ir::LoweredFunc Lower(const std::string &name,
+                      StageMap stages,
+                      const std::vector<Tensor> &tensor_args,
+                      const std::vector<Var> &scalar_args = {},
+                      const std::vector<Tensor> &temp_tensors = {},
+                      ir::Module::Builder *b = nullptr,
+                      const Target &target = common::DefaultHostTarget(),
+                      bool support_ir_schedule = false);
+
+/**
+ * \brief Lower the computation of \p tensor_args and \p scalar_args to a vector of LoweredFuncs. Each schedule group
+ * forms a LoweredFunc.
+ * @param name The name of the function.
+ * @param tensor_args The tensor arguments, where the computation logic resides.
+ * @param scalar_args The scalar arguments, which usually indicate dimensions.
+ * @param temp_tensors The temporary tensors (buffers) used in the body.
+ * @param b The module this function belongs to.
+ * @return A vector of LoweredFuncs named \p name, \p name + "_1", \p name + "_2", ... The argument list is deduced
+ * from the expression of each func.
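+ *
+ * A minimal sketch of the call (the tensors A, B, C and their stages are illustrative only):
+ * \code
+ * auto stages = CreateStages({C});
+ * std::vector<ir::LoweredFunc> fns = LowerVec("fn", stages, {A, B, C});
+ * \endcode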
+ */ +std::vector LowerVec(const std::string &name, + StageMap stages, + const std::vector &tensor_args, + const std::vector &scalar_args = {}, + const std::vector &temp_tensors = {}, + ir::Module::Builder *b = nullptr, + const Target &target = common::DefaultHostTarget(), + bool support_ir_schedule = false); + +std::vector GetArgs(const Expr &func_body, const std::vector &input_output_nodes); + +//! Collect the temporary tensors from a computational graph. +std::vector GetTempBuffers(const std::vector &tensor_args, + const poly::StageMap &stage_map, + Expr body); + +//! Collect the temporary tensors from a computational graph. +std::vector GetTempBuffers(const std::vector &args, Expr body); + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc new file mode 100644 index 0000000000000..e839fc8ef0507 --- /dev/null +++ b/paddle/cinn/lang/lower_impl.cc @@ -0,0 +1,791 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/lang/lower_impl.h" + +#include +#include +#include +#include + +#include "cinn/common/common.h" +#include "cinn/common/context.h" +#include "cinn/common/ir_util.h" +#include "cinn/ir/ir_base.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/tensor.h" +#include "cinn/optim/remove_nested_block.h" +#include "cinn/optim/replace_var_with_expr.h" +#include "cinn/optim/transform_polyfor_to_for.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace lang { +namespace detail { + +void CheckNoIslCallRemains(Expr* expr) { + auto isl_calls = ir::CollectIRNodes( + *expr, [](const Expr* expr) { return expr->As() && expr->As()->is_isl_call(); }); +#ifdef CINN_DEBUG + for (auto& item : isl_calls) { + LOG(ERROR) << "ISL call: " << item; + } +#endif + if (!isl_calls.empty()) { + LOG(WARNING) << "Some ISL call nodes remained, get " << isl_calls.size() << " isl_calls, the first one is " + << *isl_calls.begin(); + } +} + +void BindBuffer(StageMap& stages) { + absl::flat_hash_map tensor_map; + for (auto& stage : stages) { + tensor_map[stage.second->tensor()->name] = stage.second->tensor(); + } + for (auto& stage : stages) { + if (!stage.second->tensor()->buffer.defined() && !stage.second->meta.tensors_to_share_buffer_with.empty()) { + for (auto& str : stage.second->meta.tensors_to_share_buffer_with) { + if (tensor_map[str]->buffer.defined()) { + auto edited_shape = tensor_map[str]->buffer->shape; + stage.second->tensor()->Bind(tensor_map[str]->buffer); + tensor_map[str]->buffer->shape = edited_shape; + VLOG(3) << "Tensor " << stage.second->tensor()->name << " bind buffer to " << tensor_map[str]->name << " , " + << tensor_map[str]->buffer->name; + } + } + } + } +} + +Expr LowerGroup(const poly::ScheduleGroup& group, + const std::map& tuple_to_expr, + std::map* global_tensor_map, + std::unordered_map>& resized_buffer_cache, + StageMap stage_map, + ir::CudaAxisInfo* cuda_axis_info) { + BindBuffer(stage_map); + std::vector stages; + for (auto& 
node : group.nodes) { + VLOG(1) << "In LowerGroup, node id is: " << node->id(); + if (node->stage->has_expression()) { + stages.push_back(node->stage); + VLOG(1) << "stage expr " << node->stage->expr(); + } else { + VLOG(1) << "stage expression is null: " << node->stage->domain(); + } + } + + if (stages.empty()) return Expr(); + + // get isl generated expression + isl::set context(Context::isl_ctx(), "{:}"); + poly::AstGen gen(context, stages, group); + isl::ast_node ast = gen.Build(); + ir::Expr e; + + // The code where adds length 1 loop back to CINN Expr, if you do not want to + // add back, call poly::IslAstNodeToCinnExpr(ast, &e) instead of + // poly::IslAstNodeToCinnExpr(ast, gen.domain(), &e); + + VLOG(6) << "before ast to expr"; + // poly::IslAstNodeToCinnExpr(ast, &e); + poly::IslAstNodeToCinnExpr(ast, gen.domain(), &e); + // now we get a workable expression, but the statement are something like `B(((16 * po0) + po1), po2)`, we need to + // transform this to some realworld statement in CINN. + + VLOG(1) << "ast to expr: \n" << e << std::endl; + + // replace isl call to the corresponding CINN statement, we need to replace the axis at the same time. + for (auto& statement : tuple_to_expr) { + VLOG(2) << "LowerGroup working on statement: " << statement.first; + if (!gen.ContainsStatement(statement.first)) continue; + // the axis_ast_map contains the axis from the original (like `i`) to the transformed (like `i+3`). + auto axis_expr_map = gen.axis2expr(statement.first); + for (auto& item : axis_expr_map) { + VLOG(4) << "statement ast map axis [" << item.first << "] to " + << "[" << item.second << "]"; + } + + // the original CINN statements. + Expr statement_candi_expr = tuple_to_expr.at(statement.first); + + VLOG(3) << "replacing " << statement.first << " to " << statement_candi_expr; + optim::ReplaceIslCallWithExpr(&e, statement.first, statement_candi_expr, axis_expr_map); + } + CheckNoIslCallRemains(&e); + + // Update global_tensor_map + for (auto& e : stage_map) { + if (!global_tensor_map->count(e.second->id())) { + (*global_tensor_map)[e.second->id()] = ir::Tensor(e.second->tensor()); + } + } + + // mark vectorize. + { + std::map vectorizes; + for (auto& node : group.nodes) { + if (node->stage->vectorize_info().valid()) { + vectorizes[node->stage->id()] = node->stage->vectorize_info(); + } + } + MarkVectorizeMutator mutator(vectorizes); + mutator(&e); + } + + // mark unroll. + { + std::map> unrolls; + for (auto& node : group.nodes) { + if (!node->stage->unroll_info().empty()) { + unrolls[node->stage->id()] = node->stage->unroll_info(); + } + } + MarkUnrollMutator mutator(unrolls); + mutator(&e); + } + + // mark parallel. + { + std::map> parallels; + for (auto& node : group.nodes) { + if (!node->stage->parallel_info().empty()) { + parallels[node->stage->id()] = node->stage->parallel_info(); + } + } + MarkParallelMutator mutator(parallels); + mutator(&e); + } + + return e; +} + +bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage) { + if (stage->inlined()) return false; + if (stage) { + for (auto& info : stage->forloop_infos()) { + if (info.second.device == ir::DeviceAPI::GPU) { + return true; + } + } + } + return false; +} + +const char* CompuGraphNode::__type_info__ = "ComputeGraphNode"; +const char* CompuGraphNode::type_info() const { return __type_info__; } +std::string CompuGraphNode::id() const { + CHECK(tensor.defined()); + return tensor->name; +} + +/** + * \brief Add nodes to graph with dependencies. 
+ * We create a computation graph based on the tensor dependency relations. + * NOTE The graph will contain the inline tensors so that the dependency will be reserved. + * @param graph The graph + * @param t The tensor. + * @param stages The stage map. + */ +void CreateCompGraphWithInlineTensors(common::Graph* graph, + const ir::Tensor& t, + StageMap stages, + std::set* visited) { + if (visited->count(t)) return; + common::GraphNode* t_node = graph->RetrieveNode(t->name); + if (!t_node) { + t_node = graph->RegisterNode(t->name, new CompuGraphNode(t)); + } + + visited->insert(t); + + // collect dependency tensors of t + // here we just collect the tensors in Load nodes + // NOTE there may be some other cases. + auto deps = ir::CollectLoadTensors(t->body(), [](const Expr* x) { return x->as_tensor(); }); + for (const auto& dep : deps) { + auto e_tensor = dep.as_tensor_ref(); + auto* e_node = graph->RetrieveNode(e_tensor->name); + if (!e_node) { + e_node = graph->RegisterNode(e_tensor->name, new CompuGraphNode(e_tensor)); + } + e_node->Controls(t_node); + if (!visited->count(e_tensor)) { + CreateCompGraphWithInlineTensors(graph, e_tensor, stages, visited); + } + } +} + +std::unique_ptr CreateCompGraphWithInlineTensorHidden(const std::vector& tensors, + StageMap stages) { + // create a graph with inline tensor first. + std::unique_ptr graph(new common::Graph); + std::set visited; + for (auto& t : tensors) { + CreateCompGraphWithInlineTensors(graph.get(), t, stages, &visited); + } + + // greedy remove the inline tensor, each time merge the inputs of an inline tensor to its sink node. + + std::set inline_nodes; + do { + inline_nodes = graph->CollectNodes([&](const common::GraphNode* x) { + auto* comp_node = x->safe_as(); + return stages[comp_node->tensor]->inlined(); + }); + if (inline_nodes.empty()) break; + + /* + * A -> inlined -> B + * C / + * => + * A -> B + * C / + */ + for (auto* inline_node : inline_nodes) { + // remove this node, merge its inputs to the sink nodes. + auto inline_inlinks = inline_node->inlinks(); + auto inline_outlinks = inline_node->outlinks(); + + // unlink the inline node from its inputs and outputs + for (auto& link : inline_inlinks) { + link->source()->UnLinkSingleTo(link->sink()); + } + for (auto& link : inline_outlinks) { + link->source()->UnLinkSingleTo(link->sink()); + } + + // link inline node's input nodes to its output nodes. 
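+      // i.e. for every (source -> inlined) and (inlined -> sink) pair that was
+      // unlinked above, connect source directly to sink: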
+ for (auto out_edge : inline_outlinks) { + auto* out = out_edge->sink(); + for (auto in_edge : inline_inlinks) { + auto* source = in_edge->source(); + source->LinkTo(out); + } + } + + graph->DropNode(inline_node); + } + } while (!inline_nodes.empty()); + + return graph; +} + +void CompuGraphAddCtrlDepLinks(common::Graph* graph, StageMap stages) { + for (auto& x : graph->nodes()) { + auto* node = x->safe_as(); + CHECK(node); + for (auto& dep : stages[node->tensor]->ctrl_depends()) { + auto* dep_node = graph->RetrieveNode(dep->name); + if (dep_node) { + VLOG(3) << "Add control link: " << dep << " -> " << node->id(); + dep_node->Controls(node); + } + } + } +} + +std::unique_ptr CreateCompGraph(const std::vector& tensors, + StageMap stages, + bool hide_inline) { + if (hide_inline) { + auto graph = CreateCompGraphWithInlineTensorHidden(tensors, stages); + CompuGraphAddCtrlDepLinks(graph.get(), stages); + return graph; + } else { + auto graph = std::make_unique(); + std::set visited; + for (auto& t : tensors) { + CreateCompGraphWithInlineTensors(graph.get(), t, stages, &visited); + } + CompuGraphAddCtrlDepLinks(graph.get(), stages); + return graph; + } +} + +void LowerImpl::CheckArgsUnique() { + for (auto& tensor : tensor_args_) { + CHECK(!stages_[tensor]->inlined()) << "Inline tensor cannot be argument of function"; + if (!tensor->buffer.defined()) { + LOG(ERROR) << "tensor [" << tensor->name << "] buffer is null"; + continue; + } + } +} + +std::vector LowerImpl::GenerateFunctionArgumentList(Expr fn_body) { + CheckArgsUnique(); + + std::vector args; + optim::TensorWriteTeller teller; + teller.Collect(&fn_body); + + std::set arg_names; + + for (auto& scalar : scalar_args_) { + CHECK(!arg_names.count(scalar->name)); + auto* scalar_node = scalar.As(); + CHECK(scalar_node->type().valid()); + arg_names.insert(scalar->name); + + args.emplace_back(scalar, ir::Argument::IO::kInput); + } + + for (auto& tensor : tensor_args_) { + auto* tensor_node = tensor.As(); + bool is_output = teller.IsWrite(tensor->name); + VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + + // avoid duplicate + if (!tensor_node->buffer.defined()) continue; + // if a argument is already marked as kInput, mark it as kOutput and move it to the back. + if (arg_names.count(tensor_node->buffer->name)) { + auto it = std::find_if( + args.begin(), args.end(), [&](const ir::Argument& x) { return x.name() == tensor_node->buffer->name; }); + CHECK(it != args.end()); + if (it->is_input()) { + args.erase(it); + } else if (it->is_output()) { + continue; + } + } + + arg_names.insert(tensor_node->buffer->name); + + auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput; + VLOG(3) << "Collect " << (is_output ? "W" : "R") << " argument " << tensor->buffer->name; + args.emplace_back(tensor_node->buffer, io); + } + + return args; +} +// Generate Function Arguments for splitted kernel. 
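+// Unlike GenerateFunctionArgumentList above, this variant only keeps tensors
+// that actually appear in `func_iterator`, skips the given `temp_tensors`,
+// and orders the result as all input arguments followed by all outputs.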
+std::vector LowerImpl::GenFuncArgForSplitKernel(Expr func_iterator, + std::vector temp_tensors) { + CheckArgsUnique(); + + std::vector in_args; + std::vector out_args; + optim::TensorWriteTeller teller; + teller.Collect(&func_iterator); + std::set arg_names; + std::set all_tensor_names; + + for (auto& scalar : scalar_args_) { + CHECK(!arg_names.count(scalar->name)); + auto* scalar_node = scalar.As(); + CHECK(scalar_node->type().valid()); + arg_names.insert(scalar->name); + + in_args.emplace_back(scalar, ir::Argument::IO::kInput); + } + + auto all_tensors = ir::CollectIRNodes( + func_iterator, [&](const Expr* x) { return x->as_tensor() && !stages_[x->as_tensor()]->inlined(); }); + + auto all_vars = ir::CollectIRNodes(func_iterator, [&](const Expr* x) { return x->as_var(); }); + + for (auto& i : all_tensors) { + auto* tensor = i.as_tensor(); + all_tensor_names.insert(tensor->name); + VLOG(3) << "In all_tensors, it has : " << tensor->name; + if (!stages_[tensor]->meta.tensors_to_share_buffer_with.empty()) { + for (auto& i : stages_[tensor]->meta.tensors_to_share_buffer_with) { + all_tensor_names.insert(i); + VLOG(3) << "And its share_buffer_tensor is : " << i; + } + } + } + for (auto& i : all_vars) { + auto* var = i.as_var(); + VLOG(3) << "In all_vars, it has : " << var->name; + } + + for (auto& i : scalar_args_) { + VLOG(3) << "In scalar_args_, var has : " << i->name; + } + + std::set temp_tensor_names; + + for (auto& i : temp_tensors) { + VLOG(3) << "In temp_tensors, it has : " << i->name; + temp_tensor_names.insert(i->name); + } + + for (auto& tensor : tensor_args_) { + VLOG(3) << "In tensor_args_, it has : " << tensor->name; + if (temp_tensor_names.count(tensor->name) > 0) continue; + if (all_tensor_names.count(tensor->name) == 0) continue; + bool is_output = teller.IsWrite(tensor->name); + VLOG(3) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; + + // avoid duplicate + if (!tensor->buffer.defined()) { + VLOG(3) << "tensor->buffer is not defined"; + continue; + } + // if a argument is already marked as kInput, mark it as kOutput and move it to the back. + if (arg_names.count(tensor->buffer->name)) { + auto it = std::find_if( + in_args.begin(), in_args.end(), [&](const ir::Argument& x) { return x.name() == tensor->buffer->name; }); + if (it != in_args.end()) { + in_args.erase(it); + } else { + continue; + } + } + + arg_names.insert(tensor->buffer->name); + + auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput; + if (io == ir::Argument::IO::kInput) + in_args.emplace_back(tensor->buffer, io); + else + out_args.emplace_back(tensor->buffer, io); + } + if (out_args.empty()) { + for (auto& i : all_tensors) { + auto* tensor = i.as_tensor(); + VLOG(3) << "Tensor " << tensor->name; + if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) { + bool is_output = teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name); + if (is_output) out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + } + } + } + + std::vector args(in_args.begin(), in_args.end()); + args.insert(std::end(args), out_args.begin(), out_args.end()); + return args; +} + +std::vector LowerImpl::CollectTemporaryTensors() { + // a temporary should be in the comp_graph but not contained in the tensor_args. 
+ absl::flat_hash_map tensor_arg_map = GenTensorArgMap(); + absl::flat_hash_map temp_tensor_map; + + for (auto* node : compu_graph_->nodes()) { + auto* cnode = node->safe_as(); + CHECK(cnode); + if (!tensor_arg_map.count(cnode->tensor->name)) { + temp_tensor_map[cnode->tensor->name] = cnode->tensor; + } + } + + std::vector temp_tensors; + std::transform(temp_tensor_map.begin(), + temp_tensor_map.end(), + std::back_inserter(temp_tensors), + [&](const decltype(temp_tensor_map)::value_type& x) { return x.second; }); + return temp_tensors; +} + +absl::flat_hash_map LowerImpl::GenTensorArgMap() { + absl::flat_hash_map map; + for (auto& t : tensor_args_) { + map[t->name] = t; + } + return map; +} + +absl::flat_hash_map LowerImpl::GenAllTensorMap() { + absl::flat_hash_map map; + for (auto& t : CollectAllTensors()) { + map[t->name] = t; + } + return map; +} + +std::vector LowerImpl::operator()() { + std::vector stages; + std::map all_tensor_map; + for (auto& t : CollectAllTensors()) { + all_tensor_map[t->name] = t; + if (!stages_[t]->inlined()) stages.push_back(stages_[t]); + } + + auto deps = CollectExtraDependencies(); + auto schedule = poly::CreateSchedule( + stages, poly::ScheduleKind::Poly, std::vector>(deps.begin(), deps.end())); + auto func_body = GenerateFunctionBody(schedule.get()); + + std::vector result; + int num_func = 0; + for (auto& func_iterator : func_body) { + if (support_ir_schedule_) { + // add ScheduleBlockRealize + func_iterator = ir::ScheduleBlockRealize::Make( + {}, ir::ScheduleBlock::Make({}, {}, {}, common::UniqName("root"), func_iterator)); + } + std::set temp_tensor_names; + for (auto& t : temp_tensor_args_) temp_tensor_names.insert(t->name); + + auto tensor_map = + optim::InitialAssignBuffer(&func_iterator, stages_, all_tensor_map, comp_graph(), temp_tensor_names); + // copy the tensor(with buffer assigned) back to func's args. + { + for (auto& arg : tensor_args_) { + if (arg->is_placeholder_node()) continue; + if (arg->buffer.defined()) continue; + if (arg->body().As() && arg->body().type().is_void()) continue; // extern call + if (tensor_map.find(arg->name) == tensor_map.end()) { + LOG(INFO) << "Didn't find arg tensor " << arg->name << "in tensor_map.\n" + << "The function is " << fn_name_ << "\nAnd all the arg tensors are:\n"; + for (auto& i : tensor_args_) { + LOG(INFO) << i->name; + } + LOG(FATAL) << "Fatal Error!"; + } + Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer; + } + } + auto store_exprs = ir::CollectIRNodes(func_iterator, [](const Expr* x) { return x->As(); }); + std::vector new_temp_tensors; + for (auto& expr : store_exprs) { + auto* store_node = expr.As(); + CHECK(store_node); + auto* tensor = store_node->tensor.As(); + CHECK(tensor); + VLOG(3) << "In store_exprs, its name is : " << tensor->name; + CHECK(tensor->buffer.defined()); + if (tensor->buffer->memory_type != ir::MemoryType::Heap) { + new_temp_tensors.push_back(store_node->tensor.as_tensor_ref()); + } + } + + auto func_temp_tensors = CollectTemporaryTensors(); + std::vector temp_buffers; + std::unordered_set buffer_name_set; + // TODO(Superjomn) write buffer latter. 
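+    // Decide which temporaries become the function's temp_bufs: on the NVGPU
+    // target only buffers of tensors stored outside the heap qualify; on
+    // other targets every collected temporary tensor does.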
+ + if (target_ == common::DefaultNVGPUTarget()) { + for (auto& t : new_temp_tensors) { + if (!tensor_map.count(t->name)) continue; + auto& tt = tensor_map.at(t->name); + if (tt->buffer.defined() && !buffer_name_set.count(tt->buffer->name)) { + temp_buffers.push_back(tt->buffer); + buffer_name_set.insert(tt->buffer->name); + } + } + } else { + for (auto& t : func_temp_tensors) { + if (!tensor_map.count(t->name)) continue; + auto& tt = tensor_map.at(t->name); + if (tt->buffer.defined() && !buffer_name_set.count(tt->buffer->name)) { + temp_buffers.push_back(tt->buffer); + buffer_name_set.insert(tt->buffer->name); + } + } + } + + ir::LoweredFunc func; + if (target_ == common::DefaultNVGPUTarget()) { + auto func_args2 = GenFuncArgForSplitKernel(func_iterator, new_temp_tensors); + std::string new_fn_name = fn_name_; + if (num_func > 0) { + new_fn_name += "_" + std::to_string(num_func); + } + VLOG(3) << "Making func :" << new_fn_name; + for (auto& i : func_args2) { + VLOG(3) << "func_args2 is : " << i.name(); + } + for (auto& i : temp_buffers) { + VLOG(3) << "temp_buffers is : " << i->name; + } + func = ir::_LoweredFunc_::Make(new_fn_name, func_args2, func_iterator, temp_buffers); + } else { + auto func_args = GenerateFunctionArgumentList(func_iterator); + func = ir::_LoweredFunc_::Make(fn_name_, func_args, func_iterator, temp_buffers); + } + + if (support_ir_schedule_) { + optim::TransformPolyForToFor(&func->body); + optim::RemoveNestedBlock(&func->body); + func->body = ir::Block::Make({func->body}); + result.push_back(ir::LoweredFunc(func.get())); + num_func++; + } else { + optim::ComputeInlineExpand(&func->body, stages_, &all_tensor_map); + auto res = + optim::Optimize(func, target_, FLAGS_cinn_runtime_display_debug_info, /* remove_gpu_for_loops = */ false); + + if (cuda_axis_info_.size() > num_func && cuda_axis_info_[num_func].valid()) { + auto* res_func = res.as_lowered_func(); + res_func->cuda_axis_info = cuda_axis_info_[num_func]; + } + result.push_back(ir::LoweredFunc(res.get())); + num_func++; + } + } + return result; +} + +std::vector LowerImpl::CollectAllTensors() { + std::vector tensors; + auto topo_order = compu_graph_->topological_order(); // NOLINT + auto& nodes = std::get<0>(topo_order); + auto& edges = std::get<1>(topo_order); + for (auto* node : nodes) { + auto* cnode = node->safe_as(); + CHECK(cnode); + tensors.push_back(cnode->tensor); + } + return tensors; +} + +std::set> LowerImpl::CollectExtraDependencies() const { + std::set> deps; + for (auto* node : compu_graph_->nodes()) { + auto* cnode = node->safe_as(); + CHECK(cnode); + for (auto& dep : stages_[cnode->tensor]->ctrl_depends()) { + deps.emplace(dep->name, cnode->tensor->name); + } + } + return deps; +} + +std::vector LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) { + // generate the expressions for each group. 
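+  // Each poly::ScheduleGroup is lowered independently via LowerGroup. On the
+  // NVGPU target, a group that writes any buffer that is not GPU-shared or
+  // GPU-local closes the current block and starts a new function body, which
+  // is why more than one expression can be returned here.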
+ std::vector exprs; + std::vector result; + auto tensor_map = GenAllTensorMap(); + std::map tuple_to_expr; + CHECK(!schedule->groups.empty()) << "no group is generated"; + + std::map global_tensor_map; + std::unordered_map> resized_buffer_cache; + + for (auto& group : schedule->groups) { + CHECK_GT(group.nodes.size(), 0) << "group is empty"; + bool all_temp_tensor = true; + for (auto& node : group.nodes) { + if (!tensor_map.count(node->id())) { + VLOG(2) << "tensor_map doesn't count " << node->id(); + continue; + } + auto& tensor = tensor_map[node->id()]; + if (!tensor->has_expression()) continue; + all_temp_tensor = + all_temp_tensor && (stages_[tensor]->inlined() || + (tensor->buffer.defined() && (tensor->buffer->memory_type == ir::MemoryType::GPUShared || + tensor->buffer->memory_type == ir::MemoryType::GPULocal))); + auto store_body = tensor->tensor_store_expanded_body(); + if (support_ir_schedule_) { + // add schedule block of tensor computation for schedule IR + int var_counts = tensor->shape.size() + tensor->reduce_axis.size(); + std::vector int_shape; + VLOG(3) << "Tensor " << tensor->name << "'s shape is : " << utils::Join(tensor->shape, ","); + for (auto& expr : tensor->shape) { + CHECK(expr.is_constant()); + int_shape.push_back((int)expr.get_constant()); + } + for (auto& var : tensor->reduce_axis) { + CHECK(var->lower_bound.defined()); + CHECK(var->upper_bound.defined()); + CHECK(common::is_zero(var->lower_bound)); + CHECK(var->upper_bound.is_constant()); + int_shape.push_back((int)var->upper_bound.get_constant()); + } + // create block itervars, i0,i1... + std::vector block_vars; + std::vector iter_values; + std::vector axis_vars = common::GenDefaultAxis(tensor->shape.size()); + // bind var_values + axis_vars.insert(axis_vars.end(), tensor->reduce_axis.begin(), tensor->reduce_axis.end()); + for (int i = 0; i < var_counts; i++) { + block_vars.push_back(Var(Expr(0), Expr(int_shape[i]), cinn::UniqName("i" + std::to_string(i)), false)); + if (i >= tensor->shape.size()) { + block_vars[i]->is_reduce_axis = true; + axis_vars[i]->is_reduce_axis = true; + } + iter_values.push_back(axis_vars[i]); + // replace store's indice + VLOG(3) << "replace axis_var " << axis_vars[i]->name << " to block_var " << block_vars[i]; + optim::ReplaceVarWithExpr(&store_body, axis_vars[i], block_vars[i]); + } + store_body = ir::ScheduleBlockRealize::Make( + iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, tensor->name, store_body)); + // iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, common::UniqName(tensor->name), store_body)); + VLOG(3) << "store body\n" << store_body; + } + tuple_to_expr[tensor->name] = store_body; + } + + ir::CudaAxisInfo temp_cuda_axis_info; + Expr group_expr = + LowerGroup(group, tuple_to_expr, &global_tensor_map, resized_buffer_cache, stages_, &temp_cuda_axis_info); + + if (group_expr.defined()) { + cuda_axis_info_.emplace_back(std::move(temp_cuda_axis_info)); + if (target_ == common::DefaultNVGPUTarget() && !all_temp_tensor) { + exprs.push_back(group_expr); + Expr body = ir::Block::Make(exprs); + result.push_back(body); + exprs.clear(); + } else { + exprs.push_back(group_expr); + } + } + } + if (target_ == common::DefaultHostTarget()) { + Expr body = ir::Block::Make(exprs); + result.push_back(body); + exprs.clear(); + } else if (!exprs.empty()) { + Expr body = ir::Block::Make(exprs); + result.push_back(body); + exprs.clear(); + } + + return result; +} + +LowerImpl::LowerImpl(const std::string& fn_name, + StageMap stages, + const std::vector& tensor_args, + const 
std::vector& scalar_args, + const std::vector& temp_tensor_args, + const Target& target, + bool support_ir_schedule) + : fn_name_(fn_name), + stages_(stages), + tensor_args_(tensor_args), + scalar_args_(scalar_args), + temp_tensor_args_(temp_tensor_args), + target_(target), + support_ir_schedule_(support_ir_schedule) { + { // Initialize the graph + std::vector tensors(tensor_args.begin(), tensor_args.end()); + tensors.insert(std::end(tensors), temp_tensor_args.begin(), temp_tensor_args.end()); + + compu_graph_ = CreateCompGraph(tensors, stages, false /*inline_hide*/); + + VLOG(1) << "compute_graph:\n" << compu_graph_->Visualize(); + } + + // Todo: Here insert auto syncthreads() @haoze + + { // update schedule. + std::vector tensors(tensor_args.begin(), tensor_args.end()); + tensors.insert(std::end(tensors), temp_tensor_args_.begin(), temp_tensor_args_.end()); + compu_graph_ = CreateCompGraph(tensors, stages, true /*inline_hide*/); + + VLOG(1) << "Computation Graph:\n" << compu_graph_->Visualize(); + } +} + +} // namespace detail +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h new file mode 100644 index 0000000000000..923e04c90d46f --- /dev/null +++ b/paddle/cinn/lang/lower_impl.h @@ -0,0 +1,304 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cinn/common/graph_utils.h" +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/optim/buffer_assign.h" +#include "cinn/optim/compute_inline_expand.h" +#include "cinn/optim/fold_cinn_call_arguments.h" +#include "cinn/optim/optimize.h" +#include "cinn/optim/remove_nested_block.h" +#include "cinn/optim/replace_call_with_expr.h" +#include "cinn/optim/tensor_write_tell.h" +#include "cinn/optim/transform_gpu_forloop.h" +#include "cinn/optim/transform_polyfor_to_for.h" +#include "cinn/poly/ast_gen.h" + +namespace cinn { + +namespace poly { +class Stage; +} // namespace poly + +namespace lang { +namespace detail { + +/** + * After the AstGen build the forloop from isl exprs, all the ISL Call nodes should be mapped to the corresponding CINN + * expressions, there should be no remaining. + */ +void CheckNoIslCallRemains(const Expr* expr); + +/** + * \brief Lower a single group of nodes. + * + * We partition the whole computation of a function into several groups, each group is a basic element for ISL + * polyhedral computation, that is, we transform a group into a isl domain and schedule, and generate ast latter. + * + * @param group A single schedule group containing several Stages and the scheduling order. + * @param tuple_to_expr A map from isl set tuple name to CINN expressions. 
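+ * @param global_tensor_map A map from tensor name to the buffer-bound tensor, updated as each group is lowered.
+ * @param resized_buffer Bookkeeping for buffers that have already been resized, so a buffer is not resized twice.
+ * @param stage_map The stage map of the whole computation.
+ * @param cuda_axis_info Receives the CUDA grid/block binding info discovered in this group, if any.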
+ */ +Expr LowerGroup(const poly::ScheduleGroup& group, + const std::map& tuple_to_expr, + std::map* global_tensor_map, + std::unordered_set& resized_buffer, + StageMap stage_map, + ir::CudaAxisInfo* cuda_axis_info = nullptr); + +/** + * A Computation graph node. + */ +struct CompuGraphNode : public common::GraphNode { + explicit CompuGraphNode(ir::Tensor tensor) : tensor(tensor) {} + + ir::Tensor tensor; + + std::string id() const override; + const char* type_info() const override; + static const char* __type_info__; +}; + +/** + * \brief Create a computation graph using a tensor set. + * It will deduce the temporary tensors not in the \p tensors. + * It consider the `extra_depend_stages` stored in tensor.stage. + * + * @param tensors the input/output tensors of a computation. + * @param hide_inline hide inline tensor nodes. + * @return a graph. + */ +std::unique_ptr CreateCompGraph(const std::vector& tensors, + StageMap stages, + bool hide_inline = false); + +class LowerImpl { + public: + /** + * @param fn_name the name of the final output function. + * @param tensor_args the tensor arguments for the function + * @param scalar_args the scalar arguments for the function + * @param temp_tensor_args the extra temporary tensor arguments + * + * The \p tensor_args contains both input and output tensors. + */ + LowerImpl(const std::string& fn_name, + StageMap stages, + const std::vector& tensor_args, + const std::vector& scalar_args, + const std::vector& temp_tensor_args = {}, + const Target& target = common::DefaultHostTarget(), + bool support_ir_schedule = false); + + std::vector operator()(); + + /** + * Get the computational graph. + */ + const common::Graph* comp_graph() const { return compu_graph_.get(); } + + /** + * \brief generate the argument list of the final output function. + * We put the scalar_args in front of tensor_args, e.g. get tensor_args{A,B}, scalar_args{m}, the final argument list + * is {m, A, B}, the input and output tensor can be mixed in the tensor_args, the kInput and kOutput token will deduce + * from their usage in the computation. + */ + std::vector GenerateFunctionArgumentList(Expr fn_body); + + std::vector GenFuncArgForSplitKernel(Expr func_iterator, std::vector temp_tensors); + + /** + * \brief generate the body expression of the final output function. + */ + std::vector GenerateFunctionBody(const poly::Schedule* schedule); + + private: + /** + * \brief Collect the temporary tensors. + * A temporary tensor is one that is in the computation graph, not inlined and not in the tensor_args(similar to a + * temporary variable inside function). + */ + std::vector CollectTemporaryTensors(); + + /** + * \brief Check both the tensor_args and sclar_args not contain duplication (different arguemnt with the same name). + */ + void CheckArgsUnique(); + + /** + * \brief Get a map, for each tensor in the tensor_args, map from name to itself. + */ + inline absl::flat_hash_map GenTensorArgMap(); + + /** + * \brief Get a map, for each tensor in the computation graph, map from name to itself. + */ + inline absl::flat_hash_map GenAllTensorMap(); + + /** + * \brief Get all the tensors, including the input, output and temporary ones. + */ + std::vector CollectAllTensors(); + + /** + * \brief Collect the extra dependencies between tensors. + * + * The extra dependencies include + * 1. the control deps in Stage. 
+ * + * TODO(Superjomn) remove the field `extra_depend_stages` + */ + std::set> CollectExtraDependencies() const; + + private: + const std::string& fn_name_; + const std::vector& tensor_args_; + const std::vector& scalar_args_; + std::vector temp_tensor_args_; + Target target_; + + StageMap stages_; + + //! A computation graph generated from the tensor_args and scalar_args. + std::unique_ptr compu_graph_; + + //! CUDA axis info for this function. + std::vector cuda_axis_info_; + + bool support_ir_schedule_ = false; +}; + +/** + * \brief Tell whether a tensor contains some GPU related information, such some schedule. + */ +bool TensorContainsGPUInfo(ir::Tensor t, poly::Stage* stage); + +/** + * Mark the PolyFor as Vectorized if it is scheduled Vectorize in Stage. + */ +struct MarkVectorizeMutator : public ir::IRMutator { + const std::map& vectorizes; + + explicit MarkVectorizeMutator(const std::map& vectorizes) + : vectorizes(vectorizes) {} + + void operator()(Expr* expr) { ir::IRMutator::Visit(expr, expr); } + + // NOTE This mutator takes PolyFor as input, not For. + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + forloop_stack.push_back(node); + ir::IRMutator::Visit(op, expr); + forloop_stack.pop_back(); + } + + // each statement in ISL is bound to a Store node. + void Visit(const ir::Store* op, Expr* expr) override { + auto* tensor_n = op->tensor.As(); + CHECK(tensor_n); + auto it = vectorizes.find(tensor_n->name); + if (it != vectorizes.end()) { + CHECK_LT(it->second.level, forloop_stack.size()); + forloop_stack[it->second.level]->set_vectorize_info(it->second); + CHECK(it->second.valid()); + } + } + + std::vector forloop_stack; +}; + +/** + * Mark the PolyFor as Unroll if is called Unroll in Stage. + */ +struct MarkUnrollMutator : public ir::IRMutator { + std::map /*level*/> unrolls; + + explicit MarkUnrollMutator(const std::map>& unrolls) : unrolls(unrolls) {} + + void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + stack.push_back(node); + ir::IRMutator<>::Visit(op, expr); + stack.pop_back(); + } + + // each statement in ISL is bound to a Store node. + void Visit(const ir::Store* op, Expr* expr) override { + auto* tensor_n = op->tensor.As(); + CHECK(tensor_n); + auto it = unrolls.find(tensor_n->name); + if (it != unrolls.end()) { + for (int level : it->second) { + VLOG(1) << "Mark " << level << " Unrolled"; + CHECK_LT(level, stack.size()); + stack[level]->set_unrolled(); + } + } + } + + std::vector stack; +}; + +/** + * Mark the PolyFor as Parallel if is called Parallel in Stage. + */ +struct MarkParallelMutator : public ir::IRMutator { + std::map /*level*/> parallels; + + explicit MarkParallelMutator(const std::map>& parallels) : parallels(parallels) {} + + void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + stack.push_back(node); + ir::IRMutator<>::Visit(op, expr); + stack.pop_back(); + } + + // each statement in ISL is bound to a Store node. 
+  void Visit(const ir::Store* op, Expr* expr) override {
+    auto* tensor_n = op->tensor.As<ir::_Tensor_>();
+    CHECK(tensor_n);
+    auto it = parallels.find(tensor_n->name);
+    if (it != parallels.end()) {
+      for (int level : it->second) {
+        VLOG(1) << "Mark " << level << " Parallel";
+        CHECK_LT(level, stack.size());
+        stack[level]->set_parallel();
+      }
+    }
+  }
+
+  std::vector<ir::PolyFor*> stack;
+};
+
+}  // namespace detail
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/lower_impl_test.cc b/paddle/cinn/lang/lower_impl_test.cc
new file mode 100644
index 0000000000000..32b2c234093e0
--- /dev/null
+++ b/paddle/cinn/lang/lower_impl_test.cc
@@ -0,0 +1,320 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/lang/lower_impl.h"
+
+#include <gtest/gtest.h>
+
+#include "cinn/cinn.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace lang {
+namespace detail {
+
+#define CREATE_GNODE(k__) auto* n##k__ = graph->RetrieveNode(#k__);
+#define ASSERT_LINKED(a__, b__) ASSERT_TRUE(n##a__->IsLinkedTo(n##b__));
+
+TEST(CreateCompGraph, single_layer) {
+  Expr M(100);
+  Expr N(200);
+
+  Placeholder<float> A("A", {M, N});
+  Placeholder<float> B("B", {M, N});
+
+  auto C = Compute(
+      {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");
+
+  LOG(INFO) << C->expr_fields().size();
+  for (auto* e : C->expr_fields()) {
+    LOG(INFO) << "e: " << *e;
+  }
+
+  auto stages = CreateStages({C});
+  auto graph = CreateCompGraph({A, B, C}, stages);
+
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+
+  /* generated graph
+   digraph G {
+   node_0[label="A"]
+   node_1[label="B"]
+   node_2[label="C"]
+   node_0->node_2
+   node_1->node_2
+   } // end G
+  */
+
+  CREATE_GNODE(A)
+  CREATE_GNODE(B)
+  CREATE_GNODE(C)
+
+  ASSERT_TRUE(nA->IsLinkedTo(nC));
+  ASSERT_TRUE(nB->IsLinkedTo(nC));
+}
+
+TEST(CreateCompGraph, multi_layers) {
+  Expr M(100);
+  Expr N(200);
+
+  Placeholder<float> A("A", {M, N});
+  Placeholder<float> B("B", {M, N});
+
+  // A->C
+  // B->C
+  auto C = Compute(
+      {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");
+
+  // C->D
+  // B->D
+  auto D = Compute(
+      {M, N}, [&](Expr i, Expr j) { return C(i, j) + B(i, j); }, "D");
+
+  // A->E
+  // B->E
+  // C->E
+  // D->E
+  auto E = Compute(
+      {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E");
+
+  auto stages = CreateStages({C, D, E});
+  auto graph = CreateCompGraph({A, B, E}, stages);
+
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+
+  /*
+   digraph G {
+   node_0[label="A"]
+   node_1[label="B"]
+   node_3[label="C"]
+   node_4[label="D"]
+   node_2[label="E"]
+   node_0->node_2
+   node_0->node_3
+   node_1->node_2
+   node_1->node_4
+   node_1->node_3
+   node_3->node_2
+   node_3->node_4
+   node_4->node_2
+   } // end G
+   */
+
+  CREATE_GNODE(A)
+  CREATE_GNODE(B)
+  CREATE_GNODE(C)
+  CREATE_GNODE(D)
+  CREATE_GNODE(E)
+
+  ASSERT_EQ(graph->num_nodes(), 5);
+
+  ASSERT_LINKED(A, C)
+  ASSERT_LINKED(B, C)
+
+  ASSERT_LINKED(C, D)
+  ASSERT_LINKED(B, D)
+
+  ASSERT_LINKED(A, E)
+  ASSERT_LINKED(B, E)
+  ASSERT_LINKED(C, E)
ASSERT_LINKED(D, E) +} + +TEST(CreateCompGraph, multi_layers_with_extra_deps) { + Expr M(100); + Expr N(200); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + // A->C + auto C = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + 1.f; }, "C"); + + // B->D + auto D = Compute( + {M, N}, [&](Expr i, Expr j) { return B(i, j) + 1.f; }, "D"); + + // A->E + auto E = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + 1.f; }, "E"); + + auto F = Compute( + {M, N}, [&](Expr i, Expr j) { return C(i, j) + D(i, j) + E(i, j); }, "F"); + + auto stages = CreateStages({C, D, E, F}); + // C->D + stages[D]->CtrlDepend(C); + // C->E + stages[E]->CtrlDepend(C); + + auto graph = CreateCompGraph({A, B, F}, stages); + + LOG(INFO) << "graph:\n" << graph->Visualize(); + + /* + digraph G { + node_0[label="A"] + node_1[label="B"] + node_3[label="C"] + node_4[label="D"] + node_5[label="E"] + node_2[label="F"] + node_0->node_5 + node_0->node_3 + node_1->node_4 + node_3->node_2 + node_3->node_5 + node_3->node_4 + node_4->node_2 + node_5->node_2 + } // end G + */ + + CREATE_GNODE(A) + CREATE_GNODE(B) + CREATE_GNODE(C) + CREATE_GNODE(D) + CREATE_GNODE(E) + CREATE_GNODE(F) + + ASSERT_LINKED(B, D) + ASSERT_LINKED(A, C) + ASSERT_LINKED(A, E) + ASSERT_LINKED(C, E) + ASSERT_LINKED(C, F) + ASSERT_LINKED(C, D) + ASSERT_LINKED(D, F) +} + +TEST(CreateCompGraph, inline_compatible) { + Expr M(100); + Expr N(200); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + // A->C + // B->C + auto C = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C"); + + // C->D + // B->D + auto D = Compute( + {M, N}, [&](Expr i, Expr j) { return C(i, j) + B(i, j); }, "D"); + + // A->E + // B->E + // C->E + // D->E + auto E = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E"); + + auto stages = CreateStages({C, D, E}); + stages[D]->ComputeInline(); + + auto graph = CreateCompGraph({A, B, E}, stages, true); + + LOG(INFO) << "graph:\n" << graph->Visualize(); + + /* + digraph G { + node_0[label="A"] + node_1[label="B"] + node_3[label="C"] + node_2[label="E"] + node_0->node_2 + node_0->node_3 + node_1->node_2 + node_1->node_3 + node_3->node_2 + } // end G + */ + + CREATE_GNODE(A) + CREATE_GNODE(B) + CREATE_GNODE(C) + CREATE_GNODE(E) + + ASSERT_EQ(graph->num_nodes(), 4); + ASSERT_TRUE(nA->IsLinkedTo(nC)); + ASSERT_TRUE(nA->IsLinkedTo(nE)); + ASSERT_TRUE(nB->IsLinkedTo(nC)); + ASSERT_TRUE(nB->IsLinkedTo(nE)); + ASSERT_TRUE(nA->IsLinkedTo(nC)); + ASSERT_TRUE(nB->IsLinkedTo(nE)); +} + +TEST(CreateCompGraph, inline_compatible1) { + Expr M(100); + Expr N(200); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + // A->C + // B->C + auto C = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C"); + + // C->D + // B->D + auto D = Compute( + {M, N}, [&](Expr i, Expr j) { return C(i, j) + B(i, j); }, "D"); + + // A->E + // B->E + // C->E + // D->E + auto E = Compute( + {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j) + C(i, j) + D(i, j); }, "E"); + + auto stages = CreateStages({C, D, E}); + stages[C]->ComputeInline(); + + auto graph = CreateCompGraph({A, B, E}, stages, true); + + LOG(INFO) << "graph:\n" << graph->Visualize(); + + /* + digraph G { + node_0[label="A"] + node_1[label="B"] + node_3[label="D"] + node_2[label="E"] + node_0->node_2 + node_1->node_2 + node_1->node_3 + node_3->node_2 + } // end G + */ + + CREATE_GNODE(A) + CREATE_GNODE(B) + CREATE_GNODE(D) + CREATE_GNODE(E) + + ASSERT_EQ(graph->num_nodes(), 4); + + 
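+  // C was marked ComputeInline, so with hide_inline the graph keeps only A, B, D and E.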
ASSERT_TRUE(nA->IsLinkedTo(nE)); + ASSERT_TRUE(nD->IsLinkedTo(nE)); + ASSERT_TRUE(nB->IsLinkedTo(nE)); + ASSERT_TRUE(nB->IsLinkedTo(nD)); + ASSERT_TRUE(nD->IsLinkedTo(nE)); +} + +} // namespace detail +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/lower_test.cc b/paddle/cinn/lang/lower_test.cc new file mode 100755 index 0000000000000..a7f9ebbebe9e7 --- /dev/null +++ b/paddle/cinn/lang/lower_test.cc @@ -0,0 +1,155 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/lang/lower.h" + +#include + +#include + +#include "cinn/cinn.h" +#include "cinn/lang/buffer.h" +#include "cinn/lang/compute.h" +#include "cinn/lang/placeholder.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace lang { + +TEST(lower, basic) { + auto M = Expr(100); + auto N = Expr(15); + + Placeholder A("A", {Expr(M), Expr(N)}); + + auto B = Compute( + {M, N}, [=](Var i, Var j) -> Expr { return A(i, j) + 1.f; }, "B"); + + auto stages = CreateStages({B}); + + auto lower_funcs = Lower("cal_B", stages, {A, B}); + + LOG(INFO) << "lower_size " << lower_funcs; + +#define TEST_SOUTPUT(x, out) \ + std::cout << "\n" << x << std::endl; \ + EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out)); + + auto out = R"ROC( +{ + serial for (i, 0, 100) + { + serial for (j, 0, 15) + { + B[i, j] = (1.00000000f + A[i, j]) + } + } +} +)ROC"; + TEST_SOUTPUT(lower_funcs->body, out); +} + +TEST(lower, more_complex) { + Expr M(100); + Expr N(15); + Expr K(200); + + Placeholder A("A", {Expr(M), Expr(N)}); + Placeholder B("B", {Expr(N), Expr(K)}); + + auto C = Compute( + {M, N, K}, [=](Var i, Var j, Var k) -> Expr { return A(i, j) * B(j, k); }, "C"); + + auto stages = CreateStages({C}); + + auto lower_funcs = Lower("cal_C", stages, {A, B, C}); + + std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl; +} + +//! To support training, the dynamic shape support is vital. We test the corresponding lower ability here. +TEST(lower, dynamic_shape) { + Var B("B"); // B is like shape here. + Expr N(15); + Expr K(200); + + // Input is B * N, B is like batch. + Placeholder X("X", {Expr(B), Expr(N)}); + Placeholder W("W", {Expr(N), Expr(K)}); + + auto C = Compute( + {B, N, K}, [=](Var i, Var j, Var k) -> Expr { return X(i, j) * W(j, k); }, "C"); + + auto stages = CreateStages({C}); + auto lower_funcs = Lower("cal_C", stages, {X, W, C}); + + std::cout << "func:\n" << Expr(lower_funcs->self()) << std::endl; +} + +TEST(lower, lowered_call) { + Var B("B"); // B is like shape here. + Expr N(15); + + // Input is B * N, B is like batch. 
+ Placeholder X("X", {Expr(B), Expr(N)}); + Placeholder Y("Y", {Expr(B), Expr(N)}); + + auto Z = Compute( + {B, N}, [&](Var i, Var j) { return X(i, j) + Y(i, j); }, "Z"); + + std::vector return_types({{Float(32), std::vector{{B, N}}, "C"}}); + auto tensors = CallLowered("lowered_fun0", {X, Y, Z}, return_types); + auto C = tensors[0]; + + auto stages = CreateStages({X, Y, Z, C}); + + LOG(INFO) << "call_op: " << C->operation->as()->call_expr; + + auto lower_func = Lower("fn", stages, {X, Y, Z, C}); +} + +// test the temp_buffers are all collected. +TEST(lower, temp_buffer_collects) { + Expr M(10); + + Placeholder A("A", {M}); + + auto B = Compute( + {M}, [&](Expr i) -> Expr { return A(i); }, "B"); // temp + auto C = Compute( + {M}, [&](Expr i) -> Expr { return B(i); }, "C"); // temp + auto D = Compute( + {M}, [&](Expr i) -> Expr { return C(i); }, "D"); // temp + auto output = Compute( + {M}, [&](Expr i) -> Expr { return D(i); }, "output"); + + ir::Module::Builder b("somemodule", common::DefaultHostTarget()); + + auto stages = CreateStages({B, C, D, output}); + + auto fn = Lower("fn", stages, {A, output}, {}, {}, &b); + + auto module = b.Build(); + + ASSERT_EQ(module.buffers().size(), 3UL); + + std::set detected_buffer_names({"_B", "_C", "_D"}); + + for (auto& buffer : module.buffers()) { + ASSERT_TRUE(detected_buffer_names.count(buffer->name)); + } +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/packed_func.cc b/paddle/cinn/lang/packed_func.cc new file mode 100644 index 0000000000000..47f6e777c2c2d --- /dev/null +++ b/paddle/cinn/lang/packed_func.cc @@ -0,0 +1,27 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/lang/packed_func.h" + +namespace cinn { +namespace lang { + +Args::Args(cinn_value_t *values, int *type_codes, int len) { + for (int i = 0; i < len; i++) { + values_.emplace_back(values[i], type_codes[i]); + } +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/packed_func.h b/paddle/cinn/lang/packed_func.h new file mode 100644 index 0000000000000..eca3fe84cd9f6 --- /dev/null +++ b/paddle/cinn/lang/packed_func.h @@ -0,0 +1,128 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
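+
+// A minimal sketch of the calling convention declared below, mirroring the
+// tests in packed_func_test.cc (the lambda body is illustrative only):
+//
+//   PackedFunc::body_t body = [](Args args, RetValue* ret) {
+//     int a = args[0];  // CINNValue converts on demand
+//     int b = args[1];
+//     *ret = (a + b);
+//   };
+//   PackedFunc add(body);
+//   int c = add(1, 2);  // the variadic operator() packs 1 and 2 into Args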
+
+#pragma once
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "cinn/common/cinn_value.h"
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace lang {
+using common::CINNValue;
+
+/**
+ * A single argument value to Function.
+ */
+using ArgValue = CINNValue;
+
+using RetValue = CINNValue;
+
+/**
+ * Arguments of the PackedFunc.
+ */
+class Args {
+ public:
+  Args() = default;
+  Args(cinn_value_t* values, int* type_codes, int len);
+
+  //! Append an argument \p arg to the pack.
+  void Append(const ArgValue& arg) { values_.push_back(arg); }
+
+  //! Count of the arguments.
+  size_t size() const { return values_.size(); }
+
+  //! Tell whether the argument pack is empty.
+  bool empty() const { return values_.empty(); }
+
+  //! Get i-th element.
+  ArgValue& operator[](int i) { return values_[i]; }
+  const ArgValue& operator[](int i) const { return values_[i]; }
+
+  common::CINNValuePack ToValuePack() const { return common::CINNValuePack(values_); }
+
+ private:
+  std::vector<ArgValue> values_;
+};
+
+namespace detail {
+
+template <bool stop, std::size_t I, typename F>
+struct for_each_dispatcher {
+  template <typename T, typename... Args>
+  static void Run(const F& f, T&& value, Args&&... args) {
+    f(I, std::forward<T>(value));
+    for_each_dispatcher<sizeof...(Args) == 0, (I + 1), F>::Run(f, std::forward<Args>(args)...);
+  }
+};
+
+template <std::size_t I, typename F>
+struct for_each_dispatcher<true, I, F> {
+  static void Run(const F& f) {}
+};
+
+template <typename F, typename... Args>
+inline void for_each(const F& f, Args&&... args) {
+  for_each_dispatcher<sizeof...(Args) == 0, 0, F>::Run(f, std::forward<Args>(args)...);
+}
+
+struct FuncArgsSetter {
+  FuncArgsSetter(Args* args) : args_(args) {}  // NOLINT
+
+  template <typename T>
+  void operator()(size_t I, T v) const {
+    args_->Append(ArgValue(v));
+  }
+
+ private:
+  mutable Args* args_{};
+};
+
+}  // namespace detail
+
+/**
+ * A function definition with its arguments packed; all the PackedFuncs share the same signature.
+ */
+class PackedFunc {
+ public:
+  using body_t = std::function<void(Args args, RetValue* ret)>;
+
+  PackedFunc() = default;
+  explicit PackedFunc(const std::string& name) : name_(name) {}
+  explicit PackedFunc(body_t body) : body_(body) {}
+
+  template <typename... Args_>
+  inline RetValue operator()(Args_&&... args) const {
+    Args _args;
+    detail::FuncArgsSetter setter(&_args);
+    detail::for_each(setter, std::forward<Args_>(args)...);
+
+    RetValue ret_value;
+    body_(_args, &ret_value);
+    return ret_value;
+  }
+
+  inline body_t body() const { return body_; }
+
+ private:
+  std::string name_;
+  body_t body_;
+};
+
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/packed_func_test.cc b/paddle/cinn/lang/packed_func_test.cc
new file mode 100644
index 0000000000000..e374c4655e3c7
--- /dev/null
+++ b/paddle/cinn/lang/packed_func_test.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
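+
+// The detail::for_each helper in packed_func.h calls a functor once per
+// argument together with its positional index; this is how operator() appends
+// each call-site argument to the Args pack in order. A standalone sketch (the
+// printing functor is illustrative only):
+//
+//   detail::for_each(
+//       [](std::size_t i, int v) { std::printf("arg %zu = %d\n", i, v); },
+//       10, 20);  // prints: arg 0 = 10, arg 1 = 20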
+ +#include "cinn/lang/packed_func.h" + +#include + +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace lang { + +TEST(Function, test) { + PackedFunc::body_t func_body = [](Args args, RetValue* ret) { + int a = args[0]; + int b = args[1]; + *ret = (a + b); + }; + PackedFunc func(func_body); + + int c = func(1, 2); + LOG(INFO) << "c " << c; +} + +TEST(Function, test1) { + PackedFunc::body_t body = [](Args args, RetValue* ret) { + auto* msg = static_cast(args[0]); + (*ret) = msg; + }; + + PackedFunc func(body); + const char* msg = "hello world"; + char* c = func(msg); + LOG(INFO) << static_cast(c); +} + +TEST(Function, Expr) { + PackedFunc::body_t body = [](Args args, RetValue* ret) { + Expr a = args[0]; + Expr b = args[1]; + + ASSERT_EQ(a->__ref_count__.val(), 4); + ASSERT_EQ(b->__ref_count__.val(), 4); + + Expr c = a + b; + (*ret) = CINNValue(c); + }; + + PackedFunc func(body); + + Expr a(1); + Expr b(2); + ASSERT_EQ(a->__ref_count__.val(), 1); + ASSERT_EQ(b->__ref_count__.val(), 1); + + Expr ret = func(a, b); + + ASSERT_EQ(utils::GetStreamCnt(ret), "(1 + 2)"); +} + +TEST(Function, ReturnMultiValue) { + PackedFunc::body_t body = [](Args args, RetValue* ret) { + int a = args[0]; + int b = args[1]; + int c = a + b; + int d = a - b; + + *ret = common::CINNValuePack{{common::CINNValue(c), common::CINNValue(d)}}; + }; + + PackedFunc func(body); + + common::CINNValuePack ret = func(1, 2); + int c = ret[0]; + int d = ret[1]; + + EXPECT_EQ(c, 3); + EXPECT_EQ(d, -1); +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/placeholder.cc b/paddle/cinn/lang/placeholder.cc new file mode 100644 index 0000000000000..c73476c2db021 --- /dev/null +++ b/paddle/cinn/lang/placeholder.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
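+
+// CreatePlaceHolder (defined below) is the runtime-typed twin of the templated
+// Placeholder<T>: it switches on the cinn Type and instantiates the matching
+// template. A sketch of the equivalence (shape and name are illustrative):
+//
+//   ir::Tensor a = CreatePlaceHolder({32, 32}, Float(32), "X");
+//   Placeholder<float> b("X", std::vector<int>{32, 32});  // equivalent placeholder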
+
+#include "cinn/lang/placeholder.h"
+
+#include "cinn/runtime/intrinsic.h"
+
+namespace cinn {
+namespace lang {
+
+using cinn::common::bfloat16;
+using cinn::common::float16;
+
+ir::Tensor CreatePlaceHolder(const std::vector<int> &shape, Type type, const std::string &name) {
+  std::vector<Expr> expr_shape;
+  for (int s : shape) {
+    expr_shape.push_back(Expr(s));
+  }
+  return CreatePlaceHolder(expr_shape, type, name);
+}
+
+ir::Tensor CreatePlaceHolder(const std::vector<Expr> &shape, Type type, const std::string &name) {
+  if (type.is_float(32)) {
+    return Placeholder<float>(name, shape);
+  } else if (type.is_float(64)) {
+    return Placeholder<double>(name, shape);
+  } else if (type.is_bfloat16()) {
+    return Placeholder<bfloat16>(name, shape);
+  } else if (type.is_float16()) {
+    return Placeholder<float16>(name, shape);
+  } else if (type.is_int(8)) {
+    return Placeholder<int8_t>(name, shape);
+  } else if (type.is_int(16)) {
+    return Placeholder<int16_t>(name, shape);
+  } else if (type.is_int(32)) {
+    return Placeholder<int32_t>(name, shape);
+  } else if (type.is_int(64)) {
+    return Placeholder<int64_t>(name, shape);
+  } else if (type.is_uint(8)) {
+    return Placeholder<uint8_t>(name, shape);
+  } else if (type.is_uint(16)) {
+    return Placeholder<uint16_t>(name, shape);
+  } else if (type.is_uint(32)) {
+    return Placeholder<uint32_t>(name, shape);
+  } else if (type.is_uint(64)) {
+    return Placeholder<uint64_t>(name, shape);
+  } else if (type.is_bool()) {
+    return Placeholder<bool>(name, shape);
+  }
+  CINN_NOT_IMPLEMENTED
+}
+
+}  // namespace lang
+}  // namespace cinn
diff --git a/paddle/cinn/lang/placeholder.h b/paddle/cinn/lang/placeholder.h
new file mode 100644
index 0000000000000..dc945559cea23
--- /dev/null
+++ b/paddle/cinn/lang/placeholder.h
@@ -0,0 +1,115 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+
+#include "cinn/common/common.h"
+#include "cinn/ir/buffer.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/operation.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/runtime/intrinsic.h"
+
+namespace cinn {
+namespace lang {
+
+using ir::Expr;
+
+/**
+ * Placeholder
+ * @tparam T the data type of the placeholder elements.
+ */
+template <typename T>
+class Placeholder {
+ public:
+  Placeholder(const std::string &name, const std::vector<int> &shape);
+  Placeholder(const std::string &name, const std::vector<Expr> &shape);
+
+  //! Get a slice.
+ // @{ + Expr operator()(Expr a) const { return Call({a}); } + Expr operator()(Expr a, Expr b) const { return Call({a, b}); } + Expr operator()(Expr a, Expr b, Expr c) const { return Call({a, b, c}); } + Expr operator()(Expr a, Expr b, Expr c, Expr d) const { return Call({a, b, c, d}); } + Expr operator()(const std::vector &indices) const; + // @} + + Type type() const { return tensor_->type(); } + + operator ir::Tensor() { return tensor_; } + operator ir::Expr() { return Expr(tensor_); } + + ir::Tensor &operator->() { return tensor_; } + const ir::Tensor &operator->() const { return tensor_; } + + ir::Tensor tensor() const { return tensor_; } + + private: + Expr Call(const std::vector &indices) const; + + void Init(const std::string &name, const std::vector &shape); + + ir::Tensor tensor_; +}; + +template +Expr Placeholder::operator()(const std::vector &indices) const { + return tensor_(indices); +} + +template +Expr Placeholder::Call(const std::vector &indices) const { + return tensor_(indices); +} + +template +Placeholder::Placeholder(const std::string &name, const std::vector &shape) { + std::vector _shape; + for (int v : shape) _shape.push_back(Expr(v)); + Init(name, _shape); +} + +template +Placeholder::Placeholder(const std::string &name, const std::vector &shape) { + Init(name, shape); +} + +ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); + +ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); + +/// ------- details ------- +template +void Placeholder::Init(const std::string &name, const std::vector &shape) { + ir::Var buffer_ptr(Context::Global().NewName("buffer")); + buffer_ptr->set_type(type_of()); + + std::vector strides(shape.size(), Expr(1)); + Expr offset(0); + + std::vector axis; + for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i)); + + auto op = ir::PlaceholderOp::Make(name, shape, type_of()); + + tensor_ = ir::Tensor(name, type_of(), shape, shape, op, {}); + Buffer buffer(tensor_->type()); + tensor_->Bind(buffer); +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/lang/placeholder_test.cc b/paddle/cinn/lang/placeholder_test.cc new file mode 100644 index 0000000000000..5043b5280dfd6 --- /dev/null +++ b/paddle/cinn/lang/placeholder_test.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/lang/placeholder.h" + +#include + +#include "cinn/ir/ir_printer.h" + +namespace cinn { +namespace lang { + +TEST(placeholder, basic) { + Expr M(100); + Expr N(20); + + Placeholder x("x", {M, N}); + + ir::Var i("i"); + ir::Var j("j"); + + auto slice = x(i, j); + LOG(INFO) << "slice " << slice; +} + +TEST(placeholder, dynamic_shape) { + Var B("B", Int(32)); + Expr N(20); + + Placeholder x("x", {B, N}); + + Var i("i"), j("j"); + auto slice = x(i, j); +} + +} // namespace lang +} // namespace cinn diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt new file mode 100755 index 0000000000000..54407db0af697 --- /dev/null +++ b/paddle/cinn/optim/CMakeLists.txt @@ -0,0 +1,50 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS + remove_nested_block.cc + replace_call_with_expr.cc + ir_copy.cc + ir_replace.cc + replace_var_with_expr.cc + tensor_write_tell.cc + ir_simplify.cc + optimize.cc + vectorize_loops.cc + unroll_loops.cc + transform_polyfor_to_for.cc + eliminate_broadcast_in_forloop.cc + fold_cinn_call_arguments.cc + call_arg_list_to_pod_value.cc + insert_debug_log_callee.cc + lower_function_call_bind_vars.cc + extern_call_process.cc + map_extern_call.cc + compute_inline_expand.cc + buffer_assign.cc + replace_const_param_to_integer.cc + cast_simplify.cc + if_simplify.cc + lower_intrin.cc + cast_bool_to_int8.cc + collect_undefined_vars.cc + var_mod_simplify.cc + remove_schedule_block.cc + ) + +if (WITH_CUDA) + gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc) +endif() + + +cc_test(test_remove_nested_block SRCS remove_nested_block_test.cc DEPS cinncore) +cc_test(test_ir_copy SRCS ir_copy_test.cc DEPS cinncore) +cc_test(test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore) +cc_test(test_replace_call_with_expr SRCS replace_call_with_expr_test.cc DEPS cinncore) +cc_test(test_vectorize_loops SRCS vectorize_loops_test.cc DEPS cinncore ARGS ${global_test_args}) +cc_test(test_transform_polyfor_to_for SRCS transform_polyfor_to_for_test.cc DEPS cinncore ARGS ${global_test_args}) +cc_test(test_optimize SRCS optimize_test.cc DEPS cinncore) +cc_test(test_cache_read_write_replace SRCS cache_read_write_replace_test.cc DEPS cinncore) +cc_test(test_cast_simplify SRCS cast_simplify_test.cc DEPS cinncore) +cc_test(test_if_simplify SRCS if_simplify_test.cc DEPS cinncore) +cc_test(test_remove_schedule_block SRCS remove_schedule_block_test.cc DEPS cinncore) +cc_test(test_unroll_loops SRCS unroll_loops_test.cc DEPS cinncore) diff --git a/paddle/cinn/optim/buffer_assign.cc b/paddle/cinn/optim/buffer_assign.cc new file mode 100644 index 0000000000000..0b59feb339237 --- /dev/null +++ b/paddle/cinn/optim/buffer_assign.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
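+
+// InitialAssignBuffer (below) clusters tensors that must alias one buffer with
+// a union-find, then walks the computation graph in topological order so the
+// first tensor of each cluster owns the allocation and the rest share it.
+// A toy sketch of the clustering step (tensor names are illustrative):
+//
+//   common::UnionFind union_find;
+//   auto* a = union_find.AddNode(new BufferUFNode("A"));
+//   auto* c = union_find.AddNode(new BufferUFNode("C"));
+//   a->Union(c);  // A and C now form one cluster -> one shared buffer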
+
+#include "cinn/optim/buffer_assign.h"
+
+#include "cinn/common/union_find.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/lang/lower_impl.h"
+#include "cinn/optim/ir_replace.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+struct BufferUFNode : public common::UnionFindNode {
+  BufferUFNode(const std::string& x) : tensor_name(x) {}
+
+  const char* type_info() const override { return __type_info__; }
+
+  std::string tensor_name;
+  static const char* __type_info__;
+};
+
+const char* BufferUFNode::__type_info__ = "BufferUFNode";
+
+struct IRReplaceTensorMutator : ir::IRMutator<> {
+  const std::map<std::string, ir::Tensor>& tensor_map;
+  IRReplaceTensorMutator(const std::map<std::string, ir::Tensor>& tensor_map) : tensor_map(tensor_map) {}
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::_Tensor_* op, Expr* expr) override {
+    auto it = tensor_map.find(op->name);
+    if (it != tensor_map.end()) {
+      *expr = Expr(it->second);
+    }
+  }
+};
+
+}  // namespace
+
+std::map<std::string, ir::Tensor> InitialAssignBuffer(Expr* expr,
+                                                      poly::StageMap stages,
+                                                      const std::map<std::string, ir::Tensor>& all_tensor_map,
+                                                      const common::Graph* comp_graph,
+                                                      const std::set<std::string>& temp_tensor_names) {
+  // The tensor map helps to reserve only one tensor instance for a tensor (identified by the same name).
+  std::map<std::string, ir::Tensor> buffer_updated_tensor;
+
+  for (auto& item : all_tensor_map) {
+    if (stages[item.second]->inlined()) continue;
+    buffer_updated_tensor[item.second->name] = item.second;
+  }
+
+  // union-find to cluster the tensors with the same buffer.
+  common::UnionFind union_find;
+
+  // unify all the tensor occurrences with a global one, e.g. if multiple occurrences of tensor B exist in the
+  // expression, replace them with a shared one.
+  ir::CollectIRNodes(*expr, [&](const Expr* x) -> bool {
+    auto* t = x->as_tensor();
+    if (t && !stages[t]->inlined()) {
+      Reference(x) = Expr(all_tensor_map.at(t->name));
+    }
+    return false;
+  });
+
+  std::map<std::string, BufferUFNode*> uf_map;
+  for (auto& item : all_tensor_map) {
+    auto* n = union_find.AddNode(new BufferUFNode(item.second->name));
+    uf_map[item.second->name] = n->safe_as<BufferUFNode>();
+  }
+
+  for (auto& item : buffer_updated_tensor) {
+    auto* cur_n = uf_map[item.first];
+    for (auto& other : stages[item.second]->meta.tensors_to_share_buffer_with) {
+      // we might initialize the buffer in args.
+      auto* other_n = uf_map[other];
+      if (!other_n) continue;
+
+      VLOG(3) << "share buffer between " << item.first << " " << other_n->tensor_name;
+      cur_n->Union(other_n);
+    }
+  }
+
+  // Determine which tensor gets the initial buffer that will be shared across the cluster: take a topological order
+  // of the computational graph and find the tensor that comes first in each cluster.
+
+  auto _topo_order_topo_edges_ = comp_graph->topological_order();
+  auto& topo_order = std::get<0>(_topo_order_topo_edges_);
+  auto& topo_edges = std::get<1>(_topo_order_topo_edges_);
+  for (common::GraphNode* n : topo_order) {
+    auto nn = n->safe_as<lang::detail::CompuGraphNode>();
+    CHECK(nn);
+    {
+      auto it = uf_map.find(nn->tensor->name);
+      CHECK(it != uf_map.end());
+      auto& cluster_info = std::get<0>(it->second->GetRoot())->cluster_info;
+      if (cluster_info.empty()) {  // buffer owner (a tensor) of this cluster is not set yet.
+        cluster_info = nn->tensor->name;
+      }
+    }
+  }
+
+  // Get the center of a cluster, considering the following rule:
+  // 1. Prefer a tensor arg over a temp tensor.
+  auto cluster_get_center_tensor = [&](const std::vector<common::UnionFindNode*>& cluster) {
+    ir::Tensor some_tensor;
+    // try to find a node that is a tensor_arg, allocate buffer for it, and make others share buffer with it.
+ for (auto* n : cluster) { + auto* node = n->safe_as(); + bool is_temp = temp_tensor_names.count(node->tensor_name); + if (!is_temp) return all_tensor_map.at(node->tensor_name); + if (all_tensor_map.at(node->tensor_name)->buffer.defined()) { + return all_tensor_map.at(node->tensor_name); + } + some_tensor = all_tensor_map.at(node->tensor_name); + } + return some_tensor; + }; + + for (auto& cluster : union_find.GetClusters()) { + auto root_tensor = cluster_get_center_tensor(cluster); + if (!root_tensor->buffer.defined() && !root_tensor->type().is_void()) { + root_tensor->WithBuffer(); + } + + for (auto* n : cluster) { + auto& tensor = all_tensor_map.at(n->safe_as()->tensor_name); + if (tensor != root_tensor) { + auto keep_shape = root_tensor->buffer->shape; + Reference(&tensor)->Bind(root_tensor->buffer); + root_tensor->buffer->shape = keep_shape; + Reference(&tensor)->buffer->shape = keep_shape; + VLOG(3) << "keep_shape is : " << utils::GetStreamCnt(keep_shape[0]); + } + } + } + + return buffer_updated_tensor; +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/buffer_assign.h b/paddle/cinn/optim/buffer_assign.h new file mode 100644 index 0000000000000..69464607ad7de --- /dev/null +++ b/paddle/cinn/optim/buffer_assign.h @@ -0,0 +1,39 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/poly/stage.h" + +namespace cinn { +namespace optim { + +/** + * Assign buffer for tensors those are not marked as compute_inline. + * @param expr + * @param stages The stage map. + */ +std::map InitialAssignBuffer(Expr* expr, + poly::StageMap stages, + const std::map& all_tensor_map, + const common::Graph* comp_graph, + const std::set& temp_tensor_names); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/cache_read_write_replace_test.cc b/paddle/cinn/optim/cache_read_write_replace_test.cc new file mode 100755 index 0000000000000..eda11ac0ccc3d --- /dev/null +++ b/paddle/cinn/optim/cache_read_write_replace_test.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
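+
+// The two tests below pin the exact lowered IR of CacheRead/CacheWrite. Their
+// schematic effect, with names taken from the expected output strings:
+//
+//   A_read_cache[i, j]  = A[i, j];                       // CacheRead("shared", {C})
+//   C_write_cache[i, j] = A_read_cache[i, j] + B[i, j];  // compute into the write cache
+//   C[i, j]             = C_write_cache[i, j];           // CacheWrite("local")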
+ +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +TEST(CacheReadWriteReplace, basic) { + Context::Global().ResetNameId(); + Expr M(100); + Expr N(20); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + auto C = Compute( + {M, N}, [&](Expr i, Expr j) -> Expr { return A(i, j) + B(i, j); }, "C"); + + auto stages = CreateStages({C}); + + // AA cache + std::vector readers{C}; + auto AA = stages[A]->CacheRead("shared", readers, stages); + auto CC = stages[C]->CacheWrite("local", stages, C); + + auto fn = Lower("fn", stages, {A, B, C}, {}, {AA, CC}); + + LOG(INFO) << "fn:\n" << Expr(fn); + + auto target = R"ROC( +function fn (_A, _B, _C) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 20) + { + A_read_cache[i, j] = A[i, j] + } + } + serial for (i, 0, 100) + { + serial for (j, 0, 20) + { + C_write_cache[i, j] = (A_read_cache[i, j] + B[i, j]) + } + } + serial for (i, 0, 100) + { + serial for (j, 0, 20) + { + C[i, j] = C_write_cache[i, j] + } + } +} + )ROC"; + + ASSERT_EQ(utils::Trim(target), utils::GetStreamCnt(fn)); +} + +TEST(CacheReadWriteReplace, cache_write) { + Context::Global().ResetNameId(); + + Expr M(100); + Expr N(100); + + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); + + auto C = Compute( + {M, N}, [=](Expr i, Expr j) { return A(i, j) + 1.f; }, "C"); + + auto C0 = Compute( + {M, N}, [=](Expr i, Expr j) { return C(i, j) + 1.f; }, "C0"); + auto C1 = Compute( + {M, N}, [=](Expr i, Expr j) { return C0(i, j) + 1.f; }, "C1"); + + auto stages = CreateStages({A, B, C, C0, C1}); + stages[C]->ComputeInline(); + stages[C0]->ComputeInline(); + + auto Co = stages[C1]->CacheWrite("shared", stages, C1); + + auto fn = Lower("fn", stages, {A, B, Co}, {}, {C, C0, C1}); + LOG(INFO) << "\n" << fn; + + auto target_source = R"ROC( +function fn (_A, _B, _C1_write_cache) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 100) + { + C1_write_cache[i, j] = (3.00000000f + A[i, j]) + } + } + serial for (i, 0, 100) + { + serial for (j, 0, 100) + { + C1[i, j] = C1_write_cache[i, j] + } + } +} +)ROC"; + + ASSERT_EQ(utils::Trim(target_source), utils::GetStreamCnt(fn)); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/call_arg_list_to_pod_value.cc b/paddle/cinn/optim/call_arg_list_to_pod_value.cc new file mode 100644 index 0000000000000..afdddbb566a1b --- /dev/null +++ b/paddle/cinn/optim/call_arg_list_to_pod_value.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
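+
+// CallArgListToPodValue (below) rewrites a CINN call so each scalar or buffer
+// argument travels through one cinn_pod_value_t array. Schematically, in
+// pseudo-IR with hypothetical arguments:
+//
+//   before:  fn(_A_buffer, 3.14f)
+//   after:   let _pod_val_0; buffer_p_to_cinn_pod_value(_A_buffer, &_pod_val_0)
+//            let _pod_val_1; float_to_cinn_pod_value(3.14f, &_pod_val_1)
+//            let _pod_arr = [_pod_val_0, _pod_val_1]
+//            fn(_pod_arr, 2)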
+ +#include "cinn/optim/call_arg_list_to_pod_value.h" + +#include +#include +#include + +#include "cinn/common/ir_util.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/runtime/intrinsic.h" + +namespace cinn { +namespace optim { + +namespace { + +struct CallArgListToPodValueMutator : ir::IRMutator<> { + void operator()(Expr* e) { ir::IRMutator<>::Visit(e, e); } + + private: + void Visit(const ir::Call* op, Expr* expr) override { + if (op->is_cinn_call()) { + auto _oprs_args_ = pack_arg_exprs(op); // NOLINT + auto& oprs = std::get<0>(_oprs_args_); + auto& args = std::get<1>(_oprs_args_); + + Var pod_array_var(Context::Global().NewName("_pod_arr"), + type_of().with_lanes(op->total_args_count())); + + // Declare pod_array. + oprs.push_back(ir::Let::Make(pod_array_var, Expr())); + oprs.push_back(ir::intrinsics::ArgsConstruct::Make(pod_array_var, args)); + + auto new_call = ir::Call::Make(Void(), + op->name, + {pod_array_var, common::make_const(Int(32), args.size())}, + {}, + ir::CallType::CINN, + op->func, + op->value_index); + + oprs.push_back(new_call); + + *expr = ir::Block::Make(oprs); + } + } + + std::tuple /*oprs*/, std::vector /*args*/> pack_arg_exprs(const ir::Call* op) { + std::vector exprs; + std::vector args; + + auto pack_arg = [&](const Expr& arg) { + Var pod_var(Context::Global().NewName("_pod_val_"), type_of()); + + // declare the array. + exprs.push_back(ir::Let::Make(pod_var, Expr())); + + auto pod_val_addr_expr = ir::intrinsics::GetAddr::Make(pod_var); + + Expr cast; + if (arg.As()) { + cast = runtime::IntrinsicCall( + Void(), runtime::intrinsic::buffer_p_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + + } else if (arg.type() == type_of()) { + cast = runtime::IntrinsicCall( + Void(), runtime::intrinsic::float_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + } else if (arg.type() == type_of()) { + cast = runtime::IntrinsicCall( + Void(), runtime::intrinsic::int32_to_cinn_pod_value_repr, {arg}, {pod_val_addr_expr}); + } else { + CINN_NOT_IMPLEMENTED + } + + exprs.push_back(cast); + args.push_back(pod_val_addr_expr); + }; + + for (auto& arg : op->read_args) { + pack_arg(arg); + } + for (auto& arg : op->write_args) { + pack_arg(arg); + } + + return std::make_tuple(exprs, args); + } +}; + +} // namespace + +void CallArgListToPodValue(Expr* e) { CallArgListToPodValueMutator()(e); } + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/call_arg_list_to_pod_value.h b/paddle/cinn/optim/call_arg_list_to_pod_value.h new file mode 100644 index 0000000000000..2c568177ff75f --- /dev/null +++ b/paddle/cinn/optim/call_arg_list_to_pod_value.h @@ -0,0 +1,28 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +/** + * \file Transform the CINN Call node's args to cinn_pod_value_t array. 
+ */ + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +void CallArgListToPodValue(Expr* e); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/cast_bool_to_int8.cc b/paddle/cinn/optim/cast_bool_to_int8.cc new file mode 100644 index 0000000000000..86584aba5072c --- /dev/null +++ b/paddle/cinn/optim/cast_bool_to_int8.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/cast_bool_to_int8.h" + +#include + +#include "cinn/ir/ir_mutator.h" + +namespace cinn::optim { + +namespace { + +struct Mutator : public ir::IRMutator<> { + using ir::IRMutator<>::Visit; + + void Visit(const ir::Store* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + auto value = node->value; + if (op->type().is_bool() && op->value->type().is_bool()) { + value = ir::Cast::Make(Int(8), value); + *expr = ir::Store::Make(node->tensor, value, node->indices); + } + } +}; + +} // namespace + +void CastBoolToInt8(Expr* e, Target target) { + if (target.arch == Target::Arch::X86) { + Mutator mutator; + mutator.Visit(e, e); + } +} +} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_bool_to_int8.h b/paddle/cinn/optim/cast_bool_to_int8.h new file mode 100644 index 0000000000000..c7770840167e5 --- /dev/null +++ b/paddle/cinn/optim/cast_bool_to_int8.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir.h" + +namespace cinn::optim { + +/** + * Cast the expr from bool to Int8 type for llvm codegen, currently used in cpu. + * + * e.g. + * + * The expression: + * a = b + * + * to + * + * a = int8(b) + */ +void CastBoolToInt8(Expr* e, Target target); + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_simplify.cc b/paddle/cinn/optim/cast_simplify.cc new file mode 100644 index 0000000000000..eb88dbc3d29a4 --- /dev/null +++ b/paddle/cinn/optim/cast_simplify.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/cast_simplify.h" + +#include "cinn/ir/ir_mutator.h" + +namespace cinn::optim { + +using cinn::common::bfloat16; +using cinn::common::float16; + +namespace { + +template +CastType NormCastValue(T value) { + if (type_of().is_uint() || type_of().is_uint()) { + // not support uint + return static_cast(value); + } + + if (std::isinf(value)) { + return std::numeric_limits::infinity(); + } else if (std::isnan(value)) { + return std::numeric_limits::signaling_NaN(); + } else if (value >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } else if (value <= static_cast(std::numeric_limits::lowest())) { + return std::numeric_limits::lowest(); + } + return static_cast(value); +} + +struct Mutator : ir::IRMutator<> { + using ir::IRMutator<>::Visit; + + void Visit(const ir::Cast* op, Expr* expr) { + auto* node = expr->As(); + + Visit(&node->v(), &node->v()); + + if (op->type() == op->v().type()) { + *expr = op->v(); + return; + } + +#define __CAST_TO_TYPE(type__) \ + if (auto* i = op->v().As()) { \ + *expr = Expr(static_cast(i->value)); \ + } else if (auto* f = op->v().As()) { \ + *expr = Expr(static_cast(NormCastValue(f->value))); \ + } else if (auto* u = op->v().As()) { \ + *expr = Expr(static_cast(u->value)); \ + } else { \ + CINN_NOT_IMPLEMENTED \ + } + + if (op->v().is_constant()) { + if (op->type() == type_of()) { + __CAST_TO_TYPE(int8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(int64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint8_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint16_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(float) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(double) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(bool) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint32_t) + } else if (op->type() == type_of()) { + __CAST_TO_TYPE(uint64_t) + } else if (op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(bfloat16) + } else if (op->type() == type_of()) { + // Cannot simplify!!! pass + __CAST_TO_TYPE(float16) + } else { + CINN_NOT_IMPLEMENTED + } + } +#undef __CAST_TO_TYPE + } +}; + +} // namespace + +void CastSimplify(Expr* e) { + Mutator mutator; + mutator.Visit(e, e); +} + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_simplify.h b/paddle/cinn/optim/cast_simplify.h new file mode 100644 index 0000000000000..7a3e1abf1ff70 --- /dev/null +++ b/paddle/cinn/optim/cast_simplify.h @@ -0,0 +1,30 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cinn/ir/ir.h" + +namespace cinn::optim { + +/** + * Simplify the Cast nodes. + * + * There are several patterns: + * 1. the source and target type are the same, drop the Cast node + * 2. for intermediate numbers, just replace the Cast node with a Node of the target type + */ +void CastSimplify(Expr* e); + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/cast_simplify_test.cc b/paddle/cinn/optim/cast_simplify_test.cc new file mode 100644 index 0000000000000..2aad9b6789556 --- /dev/null +++ b/paddle/cinn/optim/cast_simplify_test.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/cast_simplify.h" + +#include + +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" + +namespace cinn::optim { + +TEST(CastSimplify, same_type) { + Var n("n"); + Expr a = ir::Cast::Make(Int(32), n); + LOG(INFO) << n->type(); + LOG(INFO) << a; + CastSimplify(&a); + ASSERT_EQ(utils::GetStreamCnt(a), "n"); +} + +TEST(CastSimplify, Imm_int) { + Expr a = ir::Cast::Make(Int(64), Expr(1)); + Expr c = ir::Cast::Make(Int(32), a); + LOG(INFO) << c; + CastSimplify(&c); + LOG(INFO) << c; + ASSERT_EQ(utils::GetStreamCnt(c), "1"); + ASSERT_EQ(c.type(), Int(32)); +} + +TEST(CastSimplify, Imm_double) { + Expr a = ir::Cast::Make(Float(64), Expr(2.33)); + Expr c = ir::Cast::Make(Int(32), a); + LOG(INFO) << c; + CastSimplify(&c); + LOG(INFO) << c; + ASSERT_EQ(utils::GetStreamCnt(c), "2"); + ASSERT_EQ(c.type(), Int(32)); +} + +TEST(CastSimplify, Imm_uint) { + Expr a = ir::Cast::Make(UInt(64), Expr(1)); + Expr c = ir::Cast::Make(UInt(32), a); + LOG(INFO) << c; + CastSimplify(&c); + LOG(INFO) << c; + ASSERT_EQ(utils::GetStreamCnt(c), "1"); + ASSERT_EQ(c.type(), UInt(32)); +} + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/collect_undefined_vars.cc b/paddle/cinn/optim/collect_undefined_vars.cc new file mode 100644 index 0000000000000..244342bad2cb4 --- /dev/null +++ b/paddle/cinn/optim/collect_undefined_vars.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/collect_undefined_vars.h" + +#include + +#include "cinn/ir/ir_mutator.h" + +namespace cinn::optim { + +namespace { +struct Mutator : public ir::IRMutator<> { + using ir::IRMutator<>::Visit; + std::vector undefined_vars; + std::set defined_vars; + std::set used_vars; + + void CollectVarDef(const std::string& var) { + CHECK(!defined_vars.count(var)) << "var " << var << " has been defined, please check"; + CHECK(!used_vars.count(var)) << "var " << var << " is wrongly used before definition"; + defined_vars.insert(var); + } + + void ClearVar(const std::string& var) { + defined_vars.erase(var); + used_vars.erase(var); + } + + void CollectVarUse(const std::string& var) { + used_vars.insert(var); + if (defined_vars.count(var) == 0) { + undefined_vars.push_back(var); + } + } + + void Visit(const ir::Let* op, Expr* expr) final { + Expr symbol = op->symbol; + auto var = symbol.as_var_ref(); + CHECK(var.defined()); + CollectVarDef(var->name); + auto* node = expr->As(); + Visit(&node->body, &node->body); + } + + void Visit(const ir::For* op, Expr* expr) final { + CollectVarDef(op->loop_var->name); + auto* node = expr->As(); + Visit(&node->min, &node->min); + Visit(&node->extent, &node->extent); + Visit(&node->body, &node->body); + ClearVar(op->loop_var->name); + } + + void Visit(const ir::Load* op, Expr* expr) final { + auto tensor = op->tensor.as_tensor_ref(); + CollectVarUse(tensor->name); + auto* node = expr->As(); + for (auto& idx : node->indices) Visit(&idx, &idx); + } + + void Visit(const ir::Store* op, Expr* expr) final { + auto tensor = op->tensor.as_tensor_ref(); + CollectVarUse(tensor->name); + auto* node = expr->As(); + for (auto& idx : node->indices) Visit(&idx, &idx); + Visit(&node->value, &node->value); + } + + void Visit(const ir::_Var_* op, Expr* expr) final { + CollectVarUse(op->name); + auto* node = expr->As(); + if (node->lower_bound.defined()) { + Visit(&node->lower_bound, &node->lower_bound); + } + if (node->upper_bound.defined()) { + Visit(&node->upper_bound, &node->upper_bound); + } + } + + void Visit(const ir::Reduce* op, Expr* expr) final { + for (auto& axis : op->reduce_axis) { + CollectVarDef(axis->name); + } + auto* node = expr->As(); + if (node->init.defined()) Visit(&node->init, &node->init); + Visit(&node->body, &node->body); + } +}; +} // namespace + +std::vector CollectUndefinedVars(Expr* e) { + Mutator mutator; + mutator.Visit(e, e); + return mutator.undefined_vars; +} + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/collect_undefined_vars.h b/paddle/cinn/optim/collect_undefined_vars.h new file mode 100644 index 0000000000000..25b4de3a2d4d5 --- /dev/null +++ b/paddle/cinn/optim/collect_undefined_vars.h @@ -0,0 +1,36 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
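+
+// A usage sketch for the declaration below (the expression is hypothetical):
+// in `for i { a[i] = b[i] }` the loop defines i, so only the tensors a and b
+// are reported:
+//
+//   std::vector<std::string> undef = optim::CollectUndefinedVars(&expr);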
+ +#pragma once +#include +#include + +#include "cinn/ir/ir.h" +namespace cinn::optim { + +/** + * Collect undefined vars in the scope. + * + * e.g. + * + * The expression: + * for i + * for j + * a[i, j] = b[i, j] + * + * here a, b are vars without definition + */ +std::vector CollectUndefinedVars(Expr* e); + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/compute_inline_expand.cc b/paddle/cinn/optim/compute_inline_expand.cc new file mode 100644 index 0000000000000..9e110706aab57 --- /dev/null +++ b/paddle/cinn/optim/compute_inline_expand.cc @@ -0,0 +1,233 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/compute_inline_expand.h" + +#include +#include + +#include "cinn/common/graph_utils.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/replace_var_with_expr.h" + +namespace cinn { +namespace optim { + +namespace { + +/* + * Replace a tensor(marked as compute_inline) to the expanded expression. + */ +struct TensorInlineExpandMutator : public ir::IRMutator<> { + const std::string &tensor_name_; + std::map *all_tensor_map_; + poly::StageMap stages_; + bool inline_code{false}; + bool temp_buffer{false}; + bool memory_local{false}; + std::unordered_map> resized_buffer_cache; + std::vector tensor_names; + std::vector> replace_var; + std::map var_to_extent; + + TensorInlineExpandMutator(const std::string &tensor_name, + std::map *all_tensor_map, + poly::StageMap stages) + : tensor_name_(tensor_name), all_tensor_map_(all_tensor_map), stages_(stages) {} + + void operator()(Expr *expr) { + ir::IRMutator<>::Visit(expr, expr); + for (int i = 0; i < tensor_names.size(); i++) { + for (auto &var : replace_var[i]) { + } + } + } + + void Visit(const ir::_Var_ *expr, Expr *op) override { + if (inline_code && temp_buffer) { + if (utils::Startswith(expr->name, "blockIdx") || (utils::Startswith(expr->name, "threadIdx") && memory_local)) { + *op = ir::Expr(0); + } + } + } + + void Visit(const ir::_Tensor_ *op, Expr *expr) override { + if (inline_code && utils::Endswith(op->name, "_write_cache") && + (*all_tensor_map_).at(op->name)->buffer->memory_type == ir::MemoryType::Heap) { + auto no_cache_name = op->name.substr(0, op->name.size() - 12); + VLOG(2) << "no_cache_name: " << no_cache_name; + CHECK(all_tensor_map_->count(no_cache_name)); + *expr = (*all_tensor_map_)[no_cache_name]; + } + } + + void Visit(const ir::For *op, Expr *expr) override { + CHECK(op->extent.is_constant()); + int cons_extent = (int)op->extent.get_constant(); + var_to_extent[op->loop_var->name] = op->extent; + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::PolyFor *op, Expr *expr) override { + auto extent = op->ExtractExtent(); + var_to_extent[op->iterator->name] = extent; + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::Load *op, Expr *expr) override { + auto *node = expr->As(); + auto *tensor = node->tensor.as_tensor(); + if (tensor && tensor->name == tensor_name_) 
+      *expr = tensor->inline_expanded(op->indices);
+      inline_code = true;
+      ir::IRMutator<>::Visit(expr, expr);
+      inline_code = false;
+    } else if (inline_code && tensor->buffer.defined()) {
+      bool is_heap = (*all_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::Heap;
+      if (utils::Endswith(tensor->buffer->name, "_write_cache") && is_heap) {
+        // temp fix: cache_write will wrongly replace the tensor with its cache tensor
+        auto no_cache_name = tensor->buffer->name.substr(1, tensor->buffer->name.size() - 13);
+        if (all_tensor_map_->count(no_cache_name)) {
+          ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+        } else {
+          auto *tensor = node->tensor.as_tensor();
+          CHECK(tensor);
+          // fix the computeAt case
+          auto shapes = tensor->shape;
+          CHECK_EQ(shapes.size(), node->indices.size());
+          for (int i = 0; i < shapes.size(); i++) {
+            if (common::is_zero(shapes[i] - 1)) {
+              node->indices[i] = Expr(0);
+            }
+          }
+        }
+      } else if (utils::Endswith(tensor->buffer->name, "_write_cache") ||
+                 utils::Endswith(tensor->buffer->name, "_read_cache") ||
+                 utils::Endswith(tensor->buffer->name, "_temp_buffer")) {
+#ifdef CINN_WITH_CUDA
+        auto axis_names = stages_[tensor]->axis_names();
+        auto compute_ats = stages_[tensor]->GetComputeAts();
+        if (compute_ats.size() == 1) {
+          int level_tmp;
+          for (auto &i : compute_ats) {
+            level_tmp = i.second.level;
+          }
+          std::vector<Var> replace_vars;
+          for (int j = 0; j <= level_tmp; j++) {
+            if (var_to_extent.count(axis_names[j]) == 0) continue;
+            replace_vars.push_back(Var(var_to_extent[axis_names[j]], axis_names[j]));
+          }
+          replace_var.push_back(replace_vars);
+          tensor_names.push_back(tensor->buffer->name);
+        }
+#endif
+        bool keep_buffer = temp_buffer;
+        temp_buffer = true;
+        bool keep_memory_local = memory_local;
+        if ((*all_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal) {
+          memory_local = true;
+        }
+        ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+        for (int i = 0; i < node->indices.size(); i++) {
+          auto temp = optim::IRCopy(node->indices[i]);
+          ir::IRMutator<>::Visit(&temp, &temp);
+          node->indices[i] = temp;
+        }
+        temp_buffer = keep_buffer;
+        memory_local = keep_memory_local;
+      } else {
+        ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+        for (int i = 0; i < node->indices.size(); i++) {
+          auto temp = optim::IRCopy(node->indices[i]);
+          ir::IRMutator<>::Visit(&temp, &temp);
+          node->indices[i] = temp;
+        }
+      }
+    } else {
+      ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+      for (int i = 0; i < node->indices.size(); i++) {
+        auto temp = optim::IRCopy(node->indices[i]);
+        ir::IRMutator<>::Visit(&temp, &temp);
+        node->indices[i] = temp;
+      }
+    }
+  }
+};
+
+struct SSANode : public common::GraphNode {
+  std::string id_;
+
+  explicit SSANode(const std::string &id) : id_(id) {}
+
+  std::string id() const override { return id_; }
+
+  const char *type_info() const override { return __type_info__; }
+
+  static constexpr char *__type_info__ = "optim::SSANode";
+};
+
+// TODO(Superjomn) The graph here is not a real SSA yet; it is flattened, since ir::CollectIRNodes collects all the
+// tensors recursively and cannot preserve the nesting level information. Fix it.
+struct SSABuilder : public ir::IRMutator<> {
+  common::Graph graph;
+
+  SSABuilder &operator()(Expr *expr) {
+    ir::IRMutator<>::Visit(expr, expr);
+    return *this;
+  }
+
+  void Visit(const ir::Store *op, Expr *expr) override {
+    auto *node = expr->As<ir::Store>();
+
+    auto *cur_graph_node = graph.RetrieveNode(node->tensor.as_tensor()->name);
+    if (!cur_graph_node) {
+      cur_graph_node = graph.RegisterNode(node->tensor.as_tensor()->name, new SSANode(node->tensor.as_tensor()->name));
+    }
+
+    auto deps_tensor_names = node->tensor.as_tensor()->GetDependTensorNames();
+    for (auto &t : deps_tensor_names) {
+      auto *n = graph.RetrieveNode(t);
+      if (!n) {
+        n = graph.RegisterNode(t, new SSANode(t));
+      }
+      n->Controls(cur_graph_node);
+    }
+  }
+};
+
+} // namespace
+
+void ComputeInlineExpand(Expr *expr, poly::StageMap stages, std::map<std::string, ir::Tensor> *all_tensor_map) {
+  // the inline tensors contained in the expression.
+  auto inline_tensors =
+      ir::CollectIRNodes(*expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); });
+
+  // Keep expanding while any inline tensor remains.
+  // NOTE This is a naive method that greedily expands the inline tensors until none is left; a better way is to
+  // build a SSA graph and expand the inline tensors in reverse dependency order.
+  // TODO(Superjomn) Use the SSA graph to improve this.
+  while (!inline_tensors.empty()) {
+    for (const auto &t : inline_tensors) {
+      auto *tensor = t.as_tensor();
+      TensorInlineExpandMutator(tensor->name, all_tensor_map, stages)(expr);
+    }
+
+    inline_tensors = ir::CollectLoadTensors(
+        *expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); });
+  }
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/compute_inline_expand.h b/paddle/cinn/optim/compute_inline_expand.h
new file mode 100644
index 0000000000000..9fa5baf682eb8
--- /dev/null
+++ b/paddle/cinn/optim/compute_inline_expand.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+
+#include "cinn/cinn.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * Recursively expand the tensors marked as compute_inline.
+ * @param expr the expression to modify.
+ * @param stages the stage map, used to check whether a tensor is marked inline.
+ * @param all_tensor_map the map from tensor name to tensor for all tensors in the expression.
+ */
+void ComputeInlineExpand(Expr* expr, poly::StageMap stages, std::map<std::string, ir::Tensor>* all_tensor_map);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc b/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
new file mode 100644
index 0000000000000..64ee0ba7a5664
--- /dev/null
+++ b/paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/eliminate_broadcast_in_forloop.h"
+
+#include <tuple>
+#include <vector>
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/optim/ir_replace.h"
+
+namespace cinn {
+namespace optim {
+
+namespace detail {
+
+struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::Store* op, Expr* expr) {
+    // TODO(Superjom) Support a single level of forloop as well.
+    if (forloop_stack.size() < 2) return;
+
+    auto* node = expr->As<ir::Store>();
+
+    auto broadcasts = ir::CollectIRNodes(node->value, [&](const Expr* expr) { return expr->As<ir::Broadcast>(); });
+    std::vector<Expr> let_exprs;
+
+    Var tmp;
+    Expr let_expr;
+
+    Var cur_level_loop_var = forloop_stack.back()->As<ir::For>() ? forloop_stack.back()->As<ir::For>()->loop_var
+                                                                 : forloop_stack.back()->As<ir::PolyFor>()->iterator;
+    for (Expr broadcast : broadcasts) {
+      if (ContainsLoopVar(broadcast, cur_level_loop_var)) continue;
+      VLOG(4) << "eliminating " << broadcast;
+      std::tie(let_expr, tmp) = CreateTmpLet(broadcast);
+      let_exprs.push_back(let_expr);
+
+      optim::IrReplace(expr, broadcast, tmp);
+    }
+
+    // insert the let expressions into the outer forloop.
+
+    Expr* outer_forloop = forloop_stack[forloop_stack.size() - 2];
+
+    auto& outer_forloop_body =
+        outer_forloop->As<ir::For>() ? outer_forloop->As<ir::For>()->body : outer_forloop->As<ir::PolyFor>()->body;
+
+    auto* outer_forloop_body_block = outer_forloop_body.As<ir::Block>();
+    if (outer_forloop_body_block) {
+      outer_forloop_body_block->stmts.insert(
+          std::begin(outer_forloop_body_block->stmts), let_exprs.begin(), let_exprs.end());
+
+    } else {
+      let_exprs.push_back(outer_forloop_body);
+      outer_forloop_body = ir::Block::Make(let_exprs);
+    }
+  }
+
+  bool ContainsLoopVar(Expr expr, Var loop_var) {
+    return !ir::CollectIRNodes(expr, [&](const Expr* e) -> bool {
+              return e->As<ir::_Var_>() && e->As<ir::_Var_>()->name == loop_var->name;
+            }).empty();
+  }
+
+  std::tuple<Expr, Var> CreateTmpLet(Expr body) {
+    Var tmp(Context::Global().NewName("tmp"), body.type());
+
+    Expr let_expr = ir::Let::Make(tmp, body);
+
+    return std::make_tuple(let_expr, tmp);
+  }
+
+  void Visit(const ir::For* op, Expr* expr) {
+    forloop_stack.push_back(expr);
+    ir::IRMutator<>::Visit(op, expr);
+    forloop_stack.pop_back();
+  }
+
+  void Visit(const ir::PolyFor* op, Expr* expr) {
+    forloop_stack.push_back(expr);
+    ir::IRMutator<>::Visit(op, expr);
+    forloop_stack.pop_back();
+  }
+
+  std::vector<Expr*> forloop_stack;
+};
+
+} // namespace detail
+
+void EliminateBroadcastInForloop(Expr* expr) {
+  detail::EliminateBroadcastInForloop mutator;
+  mutator(expr);
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/eliminate_broadcast_in_forloop.h b/paddle/cinn/optim/eliminate_broadcast_in_forloop.h
new file mode 100644
index 0000000000000..95f1a9a4063a6
--- /dev/null
+++ b/paddle/cinn/optim/eliminate_broadcast_in_forloop.h
@@ -0,0 +1,24 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+void EliminateBroadcastInForloop(Expr* expr);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/extern_call_process.cc b/paddle/cinn/optim/extern_call_process.cc
new file mode 100644
index 0000000000000..0f3f62c243b68
--- /dev/null
+++ b/paddle/cinn/optim/extern_call_process.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/extern_call_process.h"
+
+#include "cinn/ir/ir_mutator.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+struct ExternCallMultiOutputShallowStoreMutator : public ir::IRMutator<> {
+  void operator()(Expr* e) { ir::IRMutator<>::Visit(e, e); }
+
+ private:
+  void Visit(const ir::Store* op, Expr* expr) override {
+    auto* call = op->value.As<ir::Call>();
+    if (call && call->is_extern_call() && !call->write_args.empty()) {
+      // A multi-output extern call writes its results through write_args, so
+      // the enclosing Store is redundant; keep only the call itself.
+      *expr = op->value;
+    }
+  }
+};
+
+} // namespace
+
+void ExternCallMultiOutputShallowStore(Expr* e) { ExternCallMultiOutputShallowStoreMutator()(e); }
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/extern_call_process.h b/paddle/cinn/optim/extern_call_process.h
new file mode 100644
index 0000000000000..6f371a1134d7f
--- /dev/null
+++ b/paddle/cinn/optim/extern_call_process.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+void ExternCallMultiOutputShallowStore(Expr* e);
+
+void ExternCallRemoveTupleGetStatements(Expr* e);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/fold_cinn_call_arguments.cc b/paddle/cinn/optim/fold_cinn_call_arguments.cc
new file mode 100644
index 0000000000000..e09e7ede205fb
--- /dev/null
+++ b/paddle/cinn/optim/fold_cinn_call_arguments.cc
@@ -0,0 +1,114 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/fold_cinn_call_arguments.h"
+
+#include <unordered_set>
+#include <vector>
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+/**
+ * Fold the arguments of the Call nodes marked as CINN (a call to a LoweredFunc).
+ */
+struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> {
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Block* op, Expr* expr) override {
+    auto* node = expr->As<ir::Block>();
+    for (auto it = node->stmts.begin(); it != node->stmts.end();) {
+      if (it->As<ir::Store>()) {
+        auto* call = it->As<ir::Store>()->value.As<ir::Call>();
+        if (call && call->is_cinn_call()) {
+          // remove the duplicate calls.
+          std::string key = utils::GetStreamCnt(Expr(call));
+          if (visited_call_.count(key)) {
+            it = node->stmts.erase(it);
+            continue;
+          }
+
+          ir::IRMutator<>::Visit(&(*it), &(*it));
+          visited_call_.insert(key);
+          continue;
+        }
+      }
+
+      ir::IRMutator<>::Visit(&(*it), &(*it));
+      ++it;
+    }
+  }
+  void Visit(const ir::Store* op, Expr* expr) override {
+    auto* node = expr->As<ir::Store>();
+    if (node->value.As<ir::Call>()) {
+      auto* call = node->value.As<ir::Call>();
+      switch (call->call_type) {
+        case ir::CallType::CINN:
+          MutateCall(call);
+          *expr = node->value;
+          break;
+        case ir::CallType::Intrinsic:
+          break;
+        case ir::CallType::Extern:
+          break;
+        default:
+          CINN_NOT_IMPLEMENTED
+      }
+    }
+  }
+
+  void MutateCall(ir::Call* call) {
+    if (call->call_type == ir::CallType::Extern) return;
+
+    std::vector<Expr> read_args;
+    std::vector<Expr> write_args;
+    for (auto& arg : call->read_args) {
+      if (arg.as_tensor()) {
+        CHECK(arg.as_tensor()->buffer.defined()) << "arg tensor [" << arg.as_tensor()->name << "] has no buffer";
+        read_args.push_back(arg.as_tensor()->buffer);
+      } else {
+        read_args.push_back(arg);
+      }
+    }
+
+    for (auto& arg : call->write_args) {
+      if (arg.as_tensor()) {
+        write_args.push_back(arg.as_tensor()->buffer);
+      } else {
+        write_args.push_back(arg);
+      }
+    }
+
+    call->read_args = read_args;
+    call->write_args = write_args;
+  }
+
+ private:
+  // Avoid triggering the same call twice.
+  std::unordered_set<std::string> visited_call_;
+};
+
+} // namespace
+
+void FoldCINNCallArguments(Expr* expr) { FoldCINNCallArgumentsMutator()(expr); }
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/fold_cinn_call_arguments.h b/paddle/cinn/optim/fold_cinn_call_arguments.h
new file mode 100644
index 0000000000000..8c15438792077
--- /dev/null
+++ b/paddle/cinn/optim/fold_cinn_call_arguments.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * \brief Rewrite the Call nodes whose type is marked as CINN, packing their arguments into `void*, int` so that they
+ * can trigger a `LoweredFunc`.
+ *
+ * For example, given the input IR
+ * \code
+ * Call(some_lowered_func, a:cinn_buffer_t*, b:cinn_buffer_t*, c:cinn_buffer_t*)
+ * \endcode
+ *
+ * this pass will rewrite it to
+ * \code
+ * cinn_pod_value_t a_(a);
+ * cinn_pod_value_t b_(b);
+ * cinn_pod_value_t c_(c);
+ *
+ * cinn_args_construct(packed_args, a_, b_, c_);
+ * Call(some_lowered_func, packed_args, 3); // 3 is the number of arguments
+ * \endcode
+ */
+void FoldCINNCallArguments(Expr* expr);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/if_simplify.cc b/paddle/cinn/optim/if_simplify.cc
new file mode 100644
index 0000000000000..0d999dac84795
--- /dev/null
+++ b/paddle/cinn/optim/if_simplify.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/if_simplify.h"
+
+#include "cinn/ir/ir_mutator.h"
+
+namespace cinn::optim {
+
+namespace {
+
+struct Mutator : public ir::IRMutator<> {
+  using ir::IRMutator<>::Visit;
+
+  void Visit(const ir::IfThenElse* op, Expr* expr) {
+    auto* condition_int = op->condition.As<ir::IntImm>();
+    auto* condition_uint = op->condition.As<ir::UIntImm>();
+    int64_t value;
+    if (condition_int || condition_uint) {
+      if (condition_int) {
+        value = condition_int->value;
+      } else {
+        value = condition_uint->value;
+      }
+      if (value) {
+        *expr = op->true_case;
+      } else {
+        if (op->false_case.defined()) {
+          *expr = op->false_case;
+        } else {
+          // no false_case, so replace the statement with an empty block
+          *expr = ir::Block::Make({});
+        }
+      }
+    }
+  }
+};
+
+} // namespace
+
+void IfSimplify(Expr* e) {
+  Mutator mutator;
+  mutator.Visit(e, e);
+}
+
+} // namespace cinn::optim
diff --git a/paddle/cinn/optim/if_simplify.h b/paddle/cinn/optim/if_simplify.h
new file mode 100644
index 0000000000000..2e4fa1426ee59
--- /dev/null
+++ b/paddle/cinn/optim/if_simplify.h
@@ -0,0 +1,22 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "cinn/ir/ir.h"
+
+namespace cinn::optim {
+
+void IfSimplify(Expr* e);
+
+} // namespace cinn::optim
diff --git a/paddle/cinn/optim/if_simplify_test.cc b/paddle/cinn/optim/if_simplify_test.cc
new file mode 100644
index 0000000000000..1221a58b805cc
--- /dev/null
+++ b/paddle/cinn/optim/if_simplify_test.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/if_simplify.h"
+
+#include <gtest/gtest.h>
+
+#include <string>
+
+#include "cinn/ir/ir_printer.h"
+
+namespace cinn::optim {
+
+TEST(IfSimplify, if_true) {
+  Var n("n");
+  auto e = ir::IfThenElse::Make(Expr(1) /*true*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2)));
+
+  LOG(INFO) << "\n" << e;
+
+  IfSimplify(&e);
+
+  LOG(INFO) << e;
+
+  ASSERT_EQ(utils::GetStreamCnt(e), "int32 n = 1");
+}
+
+TEST(IfSimplify, if_false) {
+  Var n("n");
+  auto e = ir::IfThenElse::Make(Expr(0) /*false*/, ir::Let::Make(n, Expr(1)), ir::Let::Make(n, Expr(2)));
+
+  LOG(INFO) << "\n" << e;
+
+  IfSimplify(&e);
+
+  LOG(INFO) << e;
+
+  ASSERT_EQ(utils::GetStreamCnt(e), "int32 n = 2");
+}
+
+TEST(IfSimplify, if_else_empty) {
+  Var n("n");
+  auto e = ir::IfThenElse::Make(Expr(0) /*false*/, ir::Let::Make(n, Expr(1)));
+
+  LOG(INFO) << "\n" << e;
+
+  IfSimplify(&e);
+
+  LOG(INFO) << e;
+
+  std::string target = utils::Trim(R"ROC(
+{
+
+}
+)ROC");
+
+  ASSERT_EQ(utils::GetStreamCnt(e), target);
+}
+
+} // namespace cinn::optim
diff --git a/paddle/cinn/optim/insert_debug_log_callee.cc b/paddle/cinn/optim/insert_debug_log_callee.cc
new file mode 100644
index 0000000000000..6c7988b6016c8
--- /dev/null
+++ b/paddle/cinn/optim/insert_debug_log_callee.cc
@@ -0,0 +1,275 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
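+//
+// This pass instruments the lowered IR with calls to the debug-log intrinsic
+// (runtime::intrinsic::debug_log_repr), so each statement traces its own
+// execution. A rough sketch with a hypothetical tensor B: the store
+//
+//   B[i, j] = A[i, j] + 1
+//
+// gets a companion statement roughly like
+//
+//   debug_log("B[%d %d ] = %f, ", i, j, B[i, j])
+//
+// whose format string and arguments are assembled by StoreDebugInfo and
+// StoreDebugInfoBuilder below.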
+
+#include "cinn/optim/insert_debug_log_callee.h"
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "cinn/common/common.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/runtime/intrinsic.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace optim {
+using cinn::utils::StringFormat;
+
+namespace {
+
+struct StoreDebugInfoBuilder : public ir::IRVisitor {
+  std::tuple<std::string, std::vector<Expr>> operator()(const Expr *e) {
+    ir::IRVisitor::Visit(e);
+    return std::make_tuple(format_.str(), args_);
+  }
+
+ private:
+#define _BINARY_OP(Op__, repr__)          \
+  void Visit(const ir::Op__ *x) override { \
+    format_ << "(";                        \
+    ir::IRVisitor::Visit(&x->a());         \
+    format_ << " " << #repr__ << " ";      \
+    ir::IRVisitor::Visit(&x->b());         \
+    format_ << ")";                        \
+  }
+  _BINARY_OP(Add, +);
+  _BINARY_OP(Mul, *);
+  _BINARY_OP(Div, /);
+  _BINARY_OP(Sub, -);
+  _BINARY_OP(Mod, %);
+  _BINARY_OP(LT, <);
+  _BINARY_OP(LE, <=);
+  _BINARY_OP(GT, >);
+  _BINARY_OP(GE, >=);
+#undef _BINARY_OP
+
+  void Visit(const ir::Load *x) override {
+    format_ << type_specifier(x->type());
+    args_.push_back(Expr(&Reference(x)));
+  }
+
+ public:
+  void Visit(const Expr *x) override { IRVisitor::Visit(x); }
+  void Visit(const ir::IntImm *x) override {
+    format_ << type_specifier(x->type());
+    args_.push_back(&Reference(x));
+  }
+  void Visit(const ir::UIntImm *x) override {
+    format_ << type_specifier(x->type());
+    args_.push_back(&Reference(x));
+  }
+  void Visit(const ir::FloatImm *x) override {
+    format_ << type_specifier(x->type());
+    args_.push_back(&Reference(x));
+  }
+  void Visit(const ir::StringImm *x) override {}
+  void Visit(const ir::EQ *x) override {}
+  void Visit(const ir::_Var_ *x) override {}
+  void Visit(const ir::NE *x) override {}
+  void Visit(const ir::And *x) override {}
+  void Visit(const ir::Or *x) override {}
+  void Visit(const ir::Min *x) override {}
+  void Visit(const ir::Max *x) override {}
+  void Visit(const ir::Minus *x) override {}
+  void Visit(const ir::Not *x) override {}
+  void Visit(const ir::Cast *x) override {}
+  void Visit(const ir::For *x) override {}
+  void Visit(const ir::PolyFor *x) override {}
+  void Visit(const ir::Select *x) override {}
+  void Visit(const ir::IfThenElse *x) override {}
+  void Visit(const ir::Block *x) override {}
+  void Visit(const ir::Call *x) override {}
+  void Visit(const ir::Store *x) override {
+    format_ << x->tensor.as_tensor()->name << "[] = ";
+    Visit(&x->value);
+    LOG(INFO) << "store value " << x->value;
+  }
+  void Visit(const ir::Alloc *x) override {}
+  void Visit(const ir::Free *x) override {}
+  void Visit(const ir::_Buffer_ *x) override {}
+  void Visit(const ir::_Tensor_ *x) override {}
+  void Visit(const ir::_LoweredFunc_ *x) override {}
+  void Visit(const ir::_Module_ *x) override {}
+  void Visit(const ir::Let *x) override {}
+  void Visit(const ir::Reduce *x) override {}
+  void Visit(const ir::Ramp *x) override {}
+  void Visit(const ir::Broadcast *x) override {}
+  void Visit(const ir::FracOp *x) override {}
+  void Visit(const ir::Product *x) override {}
+  void Visit(const ir::Sum *x) override {}
+
+ private:
+  std::string type_specifier(const Type &type) {
+    if (type.is_float()) return "%f";
+    if (type.is_int()) return "%d";
+    CINN_NOT_IMPLEMENTED
+    return "";
+  }
+
+ private:
+  std::stringstream format_;
+  std::vector<Expr> args_;
+  bool in_load_{false};
+};
+
+struct InsertDebugLogCalleeMutator : public ir::IRMutator<> {
+  void operator()(Expr *e) { ir::IRMutator<>::Visit(e, e); }
+
+  void Visit(const ir::_LoweredFunc_ *op, Expr *expr) {
+    auto *node = expr->As<ir::_LoweredFunc_>();
+    auto *body_block = node->body.As<ir::Block>();
+    CHECK(body_block);
+
+    auto msg = StringFormat("running : %s", GetDebugString(*expr).c_str());
+    auto debug_node = CreateDebugStatement(msg);
+
+    ir::IRMutator<>::Visit(&node->body, &node->body);
+
+    auto deal_with_exprs = [&](std::vector<Expr> *exprs) {  // deal with op->argument_prepare_exprs
+      std::vector<Expr> new_stmts;
+      for (auto &expr : *exprs) {
+        auto msg = StringFormat("running : %s", GetDebugString(expr).c_str());
+        new_stmts.push_back(CreateDebugStatement(msg));
+        new_stmts.push_back(expr);
+      }
+      *exprs = new_stmts;
+    };
+
+    deal_with_exprs(&node->alloc_output_buffer_exprs);
+    deal_with_exprs(&node->dealloc_output_buffer_exprs);
+    deal_with_exprs(&node->buffer_data_cast_exprs);
+    deal_with_exprs(&node->argument_prepare_exprs);
+
+    body_block->stmts.insert(body_block->stmts.begin(), debug_node);
+  }
+
+  void Visit(const ir::Block *op, Expr *expr) {
+    auto *node = expr->As<ir::Block>();
+    std::vector<Expr> new_stmts;
+    for (auto &e : op->stmts) {
+      if (!IsDebugInfoNode(e)) {
+        std::string msg;
+        if (!e.As<ir::Store>()) {
+          msg = StringFormat("running: %s", GetDebugString(e).c_str());
+          auto debug_info_node = CreateDebugStatement(msg);
+          new_stmts.push_back(debug_info_node);
+        } else {
+          auto _msg_args_ = StoreDebugInfo(e);
+          auto &msg = std::get<0>(_msg_args_);
+          auto &args = std::get<1>(_msg_args_);
+          auto debug_info_node = CreateDebugStatement("running: " + msg, std::move(args));
+          new_stmts.push_back(debug_info_node);
+        }
+      }
+
+      ir::IRMutator<>::Visit(&e, &Reference(&e));
+
+      new_stmts.push_back(e);
+
+      if (!IsDebugInfoNode(e) && e.As<ir::Store>()) {
+        auto _msg_args_ = StoreDebugInfo(e);
+        auto &msg = std::get<0>(_msg_args_);
+        auto &args = std::get<1>(_msg_args_);
+        auto debug_info_node = CreateDebugStatement(msg, std::move(args));
+        new_stmts.push_back(debug_info_node);
+
+        {  // detailed debug
+          auto _format_args_ = StoreDebugInfoBuilder()(&e);
+          auto &format = std::get<0>(_format_args_);
+          auto &args = std::get<1>(_format_args_);
+          new_stmts.push_back(CreateDebugStatement(format, std::move(args)));
+        }
+      }
+    }
+
+    node->stmts = new_stmts;
+  }
+
+  std::string GetDebugString(const Expr &e) {
+    std::stringstream ss;
+    switch (e.node_type()) {
+      case ir::IrNodeTy::Block:
+        ss << "<block>";
+        break;
+      case ir::IrNodeTy::For: {
+        auto *node = e.As<ir::For>();
+        ss << "<For " << node->loop_var << " in [" << node->min << ", " << node->extent << ")>";
+        break;
+      }
+      case ir::IrNodeTy::PolyFor: {
+        auto *node = e.As<ir::PolyFor>();
+        ss << "<PolyFor " << node->iterator << " in [" << node->init << ", " << node->ExtractExtent() << ")"
+           << " with condition: " << node->condition << ">";
+        break;
+      }
+      case ir::IrNodeTy::_LoweredFunc_: {
+        auto *node = e.As<ir::_LoweredFunc_>();
+        ss << "<LoweredFunc " << node->name << ">";
+        break;
+      }
+      case ir::IrNodeTy::Call: {
+        auto *node = e.As<ir::Call>();
+        if (node->name == runtime::intrinsic::debug_log_repr) {
+          return "";
+        }
+        ss << e;
+        break;
+      }
+      default:
+        ss << "NodeTy " << e->node_type() << ": " << e;
+        break;
+    }
+
+    return ss.str();
+  }
+
+  std::tuple<std::string, std::vector<Expr>> StoreDebugInfo(const Expr &e) {
+    const auto *node = e.As<ir::Store>();
+
+    std::stringstream format_ss;
+
+    {  // destination
+      format_ss << node->tensor.as_tensor()->name << "[";
+      for (auto &index : node->indices) format_ss << "%d ";
+      format_ss << "] = %f";
+    }
+
+    format_ss << ", ";
+
+    std::vector<Expr> val_reprs;
+    for (auto &index : node->indices) val_reprs.push_back(index);
+    val_reprs.push_back(ir::Load::Make(node->tensor, node->indices));
+
+    return std::make_tuple(format_ss.str(), val_reprs);
+  }
+
+  inline bool IsDebugInfoNode(const Expr &e) {
+    return e.As<ir::Call>() && e.As<ir::Call>()->name == runtime::intrinsic::debug_log_repr;
+  }
+
+  Expr CreateDebugStatement(const std::string &msg, std::vector<Expr> &&args = {}) {
+    args.insert(args.begin(), Expr(msg));
+    return ir::Call::Make(
+        Void(), runtime::intrinsic::debug_log_repr, args, {}, ir::CallType::Intrinsic, ir::FunctionRef(), 0);
+  }
+};
+
+} // namespace
+
+void InsertDebugLogCallee(Expr *e) { InsertDebugLogCalleeMutator()(e); }
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/insert_debug_log_callee.h b/paddle/cinn/optim/insert_debug_log_callee.h
new file mode 100644
index 0000000000000..470c909d36ce1
--- /dev/null
+++ b/paddle/cinn/optim/insert_debug_log_callee.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+void InsertDebugLogCallee(Expr* e);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_copy.cc b/paddle/cinn/optim/ir_copy.cc
new file mode 100644
index 0000000000000..0603a2998def7
--- /dev/null
+++ b/paddle/cinn/optim/ir_copy.cc
@@ -0,0 +1,480 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/ir_copy.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/common/common.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/ir/module.h"
+
+namespace cinn {
+namespace optim {
+using namespace ir;  // NOLINT
+
+struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
+  // Use maps to unify all the copied tensors and buffers.
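+  // Without unification, two references to the same tensor would be copied
+  // into two distinct nodes. Routing every copy through these maps keeps one
+  // node per name; e.g. (hypothetical tensor A) both loads in A[i] + A[i + 1]
+  // point at the same copied _Tensor_ afterwards.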
+  std::map<std::string, ir::_Tensor_*> tensor_map;
+  std::map<std::string, ir::_Buffer_*> buffer_map;
+
+  Expr Visit(const Expr* op) override { return IRVisitorBase::Visit(op); }
+
+ protected:
+  // The methods of ir nodes follow the order defined in node.h
+
+  Expr Visit(const ir::IntImm* op) override { return Expr(make_shared<IntImm>(op->type(), op->value)); }
+  Expr Visit(const ir::UIntImm* op) override { return Expr(make_shared<UIntImm>(op->type(), op->value)); }
+  Expr Visit(const ir::FloatImm* op) override { return Expr(make_shared<FloatImm>(op->type(), op->value)); }
+  Expr Visit(const ir::StringImm* op) override { return Expr(common::make_shared<StringImm>(op->value)); }
+
+  Expr Visit(const ir::Cast* op) override {
+    auto v = Visit(&op->v());
+    return Cast::Make(op->type(), v);
+  }
+
+  Expr Visit(const Select* op) override {
+    auto condition = Visit(&op->condition);
+    auto true_value = Visit(&op->true_value);
+    auto false_value = Visit(&op->false_value);
+    return Select::Make(condition, true_value, false_value);
+  }
+
+  Expr Visit(const IfThenElse* op) override {
+    auto condition = Visit(&op->condition);
+    auto true_case = Visit(&op->true_case);
+    Expr false_case;
+    if (op->false_case.defined()) false_case = Visit(&op->false_case);
+    return IfThenElse::Make(condition, true_case, false_case);
+  }
+
+  Expr Visit(const Block* op) override {
+    std::vector<Expr> stmts;
+    for (auto& s : op->stmts) {
+      stmts.push_back(Visit(&s));
+    }
+    return Block::Make(stmts);
+  }
+
+  Expr Visit(const Call* op) override {
+    auto read_args = Visit(op->read_args);
+    auto write_args = Visit(op->write_args);
+    return Call::Make(op->type(), op->name, read_args, write_args, op->call_type, FunctionRef(), 0, op->attrs);
+  }
+
+  Expr Visit(const _Var_* op) override {
+    auto* n = make_shared<_Var_>();
+
+    n->name = op->name;
+    n->is_reduce_axis = op->is_reduce_axis;
+    n->set_type(op->type());
+
+    if (op->lower_bound.defined()) {
+      n->lower_bound = Visit(&op->lower_bound);
+    }
+    if (op->upper_bound.defined()) {
+      n->upper_bound = Visit(&op->upper_bound);
+    }
+
+    return Expr(n);
+  }
+
+  Expr Visit(const Load* op) override {
+    auto tensor = Visit(&op->tensor);
+    std::vector<Expr> indices;
+    for (auto& idx : op->indices) {
+      indices.push_back(Visit(&idx));
+    }
+    return Load::Make(tensor, indices);
+  }
+
+  Expr Visit(const Store* op) override {
+    auto tensor = Visit(&op->tensor);
+    auto value = Visit(&op->value);
+    std::vector<Expr> indices;
+    for (auto& idx : op->indices) indices.push_back(Visit(&idx));
+
+    return Store::Make(tensor, value, indices);
+  }
+
+  Expr Visit(const Alloc* op) override {
+    auto extents = Visit(op->extents);
+    Expr condition;
+    Expr body;
+    if (op->condition.defined()) condition = Visit(&op->condition);
+    if (op->body.defined()) body = Visit(&op->body);
+
+    return Alloc::Make(op->destination, op->type(), extents, condition, body);
+  }
+
+  Expr Visit(const Free* op) override { return Free::Make(op->destination); }
+
+  Expr Visit(const _Buffer_* op) override {
+    if (buffer_map.count(op->name)) {
+      return buffer_map[op->name];
+    }
+
+    auto shape = Visit(op->shape);
+    auto strides = Visit(op->strides);
+    auto name = op->name;
+    auto scope = op->scope;
+    int data_alignment = op->data_alignment;
+    auto elem_offset = Visit(&op->elem_offset);
+    int offset_factor = op->offset_factor;
+    Target target = op->target;
+
+    auto new_node = _Buffer_::Make(name, shape);
+    new_node->strides = strides;
+    new_node->dtype = op->dtype;  // copy data element's type.
+    new_node->name = name;
+    new_node->scope = scope;
+    new_node->data_alignment = data_alignment;
+    new_node->elem_offset = elem_offset;
+    new_node->offset_factor = offset_factor;
+    new_node->target = target;
+    new_node->memory_type = op->memory_type;
+    new_node->set_type(op->type());
+    op->CopyMeta(new_node.As<ir::_Buffer_>());
+
+    buffer_map[op->name] = new_node->self();
+
+    return Expr(ir::Buffer(new_node));
+  }
+
+  Expr Visit(const _Tensor_* op) override {
+    if (tensor_map.count(op->name)) {
+      return tensor_map[op->name];
+    }
+
+    auto shape = Visit(op->shape);
+    auto domain = Visit(op->domain);
+    auto buffer_expr = Expr(op->buffer);
+    // TODO(Superjomn) copy the operation.
+    auto operation = op->operation;
+    auto name = op->name;
+    auto tensor = make_shared<_Tensor_>();
+
+    if (buffer_expr.defined()) {
+      auto buffer = Visit(&buffer_expr);
+      tensor->buffer = buffer.as_buffer_ref();
+    }
+    tensor->domain = domain;
+    tensor->shape = shape;
+    tensor->reduce_axis = op->reduce_axis;
+    tensor->operation = operation;
+    tensor->name = name;
+    tensor->set_type(op->type());
+    tensor->axis_ = op->axis_;
+
+    tensor_map[tensor->name] = tensor;
+
+    return tensor;
+  }
+
+  Expr Visit(const For* op) override {
+    auto extent = Visit(&op->extent);
+    auto min = Visit(&op->min);
+    auto body = Visit(&op->body);
+
+    return ir::For::Make(
+        op->loop_var, min, extent, op->for_type(), op->device_api, body, op->vectorize_info(), op->bind_info());
+  }
+
+  Expr Visit(const ir::PolyFor* op) override {
+    auto init = Visit(&op->init);
+    auto condition = Visit(&op->condition);
+    auto inc = Visit(&op->inc);
+    auto body = Visit(&op->body);
+    auto expr = PolyFor::Make(op->iterator,
+                              init,
+                              condition,
+                              inc,
+                              op->for_type(),
+                              op->device_api,
+                              body,
+                              op->vectorize_info(),
+                              op->bind_info());
+    return expr;
+  }
+
+  Expr Visit(const ir::_Module_* op) override {
+    std::vector<Expr> buffers;
+    std::vector<Expr> functions;
+    std::vector<Expr> submodules;
+
+    for (auto& expr : op->buffers) {
+      buffers.push_back(Visit(&expr));
+    }
+
+    for (auto& expr : op->functions) {
+      functions.push_back(Visit(&expr));
+    }
+
+    for (auto& expr : op->submodules) {
+      submodules.push_back(Visit(&expr));
+    }
+
+    auto res = ir::_Module_::Make(op->name, op->target);
+    res->buffers = buffers;
+    res->functions = functions;
+    res->submodules = submodules;
+
+    return Expr(res);
+  }
+
+  Expr Visit(const _LoweredFunc_* op) override {
+    auto func = make_shared<_LoweredFunc_>();
+
+    func->name = op->name;
+    func->args = op->args;
+    func->body = Visit(&op->body);
+    func->temp_bufs = op->temp_bufs;
+
+    func->device_api = op->device_api;
+
+    func->cuda_axis_info = op->cuda_axis_info;
+
+    std::vector<Expr> alloc_output_buffer_exprs;
+    std::vector<Expr> dealloc_output_buffer_exprs;
+    std::vector<Expr> buffer_data_cast_exprs;
+    std::vector<Expr> argument_prepare_exprs;
+
+#define COPY_ADD_FIELD(field__)      \
+  for (auto& expr : op->field__) {   \
+    field__.push_back(Visit(&expr)); \
+  }                                  \
+  func->field__ = std::move(field__);
+
+    COPY_ADD_FIELD(alloc_output_buffer_exprs);
+    COPY_ADD_FIELD(dealloc_output_buffer_exprs);
+    COPY_ADD_FIELD(buffer_data_cast_exprs);
+    COPY_ADD_FIELD(argument_prepare_exprs);
+
+    return func;
+  }
+
+  Expr Visit(const Let* op) override {
+    auto value = Visit(&op->symbol);
+    auto body = Visit(&op->body);
+
+    return Let::Make(value, body);
+  }
+
+  Expr Visit(const Reduce* op) override {
+    auto init = Visit(&op->init);
+    auto body = Visit(&op->body);
+    std::vector<Var> reduce_axis(op->reduce_axis.begin(), op->reduce_axis.end());
+    return Reduce::Make(op->reduce_type, init, body, reduce_axis);
+  }
+  Expr Visit(const Ramp* op) override {
+    auto base = Visit(&op->base);
+    auto stride = Visit(&op->stride);
+    int lanes = op->lanes;
+    return Ramp::Make(base, stride, lanes);
+  }
+
+  Expr Visit(const Broadcast* op) override {
+    auto value = Visit(&op->value);
+    int lanes = op->lanes;
+    CHECK(value.defined());
+    CHECK(value.type().valid());
+
+    auto* n = make_shared<Broadcast>();
+    n->value = value;
+    n->lanes = lanes;
+    return Expr(n);
+  }
+
+  Expr Visit(const FracOp* op) override {
+    auto a = Visit(&op->a());
+    auto b = Visit(&op->b());
+    CHECK(a.defined());
+    CHECK(b.defined());
+
+    auto* n = make_shared<FracOp>();
+    n->a() = a;
+    n->b() = b;
+    return Expr(n);
+  }
+
+  Expr Visit(const Product* op) override {
+    std::vector<Expr> operands;
+    for (auto& v : op->operands()) {
+      operands.push_back(Visit(&v));
+    }
+    return Product::Make(operands);
+  }
+
+  Expr Visit(const Sum* op) override {
+    std::vector<Expr> operands;
+    for (auto& v : op->operands()) {
+      operands.push_back(Visit(&v));
+    }
+    return Sum::Make(operands);
+  }
+
+  Expr Visit(const ir::PrimitiveNode* op) override {
+    std::vector<std::vector<Expr>> arguments;
+    for (auto& args : op->arguments) {
+      arguments.push_back(Visit(args));
+    }
+
+    auto n = common::make_shared<ir::PrimitiveNode>();
+    n->name = op->name;
+    n->attrs = op->attrs;  // attrs are PODs
+    n->arguments = arguments;
+    return Expr(n);
+  }
+
+  Expr Visit(const ir::_BufferRange_* op) {
+    std::vector<Var> ranges;
+    for (auto& range_var : op->ranges) {
+      auto* var = range_var.As<_Var_>();
+      ranges.push_back(Visit(var));
+    }
+    return ir::_BufferRange_::Make(Visit(&op->buffer), ranges);
+  }
+
+  Expr Visit(const ir::ScheduleBlock* op) {
+    std::vector<Var> iter_vars;
+    for (auto iter_var : op->iter_vars) {
+      auto* var = iter_var.As<_Var_>();
+      CHECK(var);
+      iter_vars.push_back(Visit(var));
+    }
+    std::vector<Expr> read_buffers;
+    for (auto buffer_range : op->read_buffers) {
+      read_buffers.push_back(Visit(&buffer_range));
+    }
+    std::vector<Expr> write_buffers;
+    for (auto buffer_range : op->write_buffers) {
+      write_buffers.push_back(Visit(&buffer_range));
+    }
+    Expr res = ir::ScheduleBlock::Make(iter_vars, read_buffers, write_buffers, op->name, Visit(&op->body));
+    res.As<ir::ScheduleBlock>()->attrs = op->attrs;
+    return res;
+  }
+
+  Expr Visit(const ir::ScheduleBlockRealize* op) {
+    std::vector<Expr> iter_values;
+    for (auto iter_value : op->iter_values) {
+      iter_values.push_back(Visit(&iter_value));
+    }
+    return ir::ScheduleBlockRealize::Make(iter_values, Visit(&op->schedule_block));
+  }
+
+#define __(x__) Expr Visit(const ir::intrinsics::x__* op);
+  INTRINSIC_KIND_FOR_EACH(__)
+#undef __
+
+  Expr Visit(const ir::IntrinsicOp* op) override {
+    switch (op->getKind()) {
+#define __(x__)                                        \
+  case ir::IntrinsicKind::k##x__:                      \
+    return Visit(llvm::dyn_cast<ir::intrinsics::x__>(op));
+      INTRINSIC_KIND_FOR_EACH(__)
+#undef __
+    }
+  }
+
+#define OP_BINARY_HANDLE(op__)               \
+  Expr Visit(const ir::op__* op) override {  \
+    auto a = IRVisitorBase::Visit(&op->a()); \
+    auto b = IRVisitorBase::Visit(&op->b()); \
+    return op__::Make(a, b);                 \
+  }
+  NODETY_BINARY_OP_FOR_EACH(OP_BINARY_HANDLE)
+#undef OP_BINARY_HANDLE
+
+#define OP_UNARY_HANDLE(op__)                \
+  Expr Visit(const op__* op) override {      \
+    auto v = IRVisitorBase::Visit(&op->v()); \
+    return op__::Make(v);                    \
+  }
+  NODETY_UNARY_OP_FOR_EACH(OP_UNARY_HANDLE)
+#undef OP_UNARY_HANDLE
+
+  std::vector<Expr> Visit(const std::vector<Expr>& vs) {
+    std::vector<Expr> copied;
+    for (auto& e : vs) {
+      copied.push_back(Visit(&e));
+    }
+    return copied;
+  }
+};
+
+Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferGetDataHandle* op) {
+  return intrinsics::BufferGetDataHandle::Make(Visit(&op->buffer));
+}
+Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferGetDataConstHandle* op) {
+  return intrinsics::BufferGetDataConstHandle::Make(Visit(&op->buffer));
+}
+Expr IRCopyVisitor::Visit(const ir::intrinsics::PodValueToX* op) {
+  return intrinsics::PodValueToX::Make(Visit(&op->pod_value_ptr), op->GetOutputType(0));
+}
+Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferCreate* op) {
+  return intrinsics::BufferCreate::Make(Visit(&op->buffer));
+}
+Expr IRCopyVisitor::Visit(const ir::intrinsics::GetAddr* op) { return intrinsics::GetAddr::Make(Visit(&op->data)); }
+Expr IRCopyVisitor::Visit(const ir::intrinsics::ArgsConstruct* op) {
+  llvm::SmallVector<Expr> args;
+  for (auto& arg : op->args) {
+    args.push_back(Visit(&arg));
+  }
+  return intrinsics::ArgsConstruct::Make(op->var, args);
+}
+Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) {
+  return intrinsics::BuiltinIntrin::Make(op->name, op->args, op->id, op->arg_nums, op->type());
+}
+
+Expr IRCopy(Expr x) {
+  IRCopyVisitor visitor;
+  auto copied = visitor.Visit(&x);
+  return copied;
+}
+
+std::vector<Expr> IRCopy(const std::vector<Expr>& x) {
+  std::vector<Expr> res;
+  for (auto& i : x) {
+    res.emplace_back(IRCopy(i));
+  }
+  return res;
+}
+
+ir::ModuleExpr IRCopy(const ir::ModuleExpr& x) { return ir::ModuleExpr(IRCopy(x.GetExprs())); }
+
+ir::LoweredFunc IRCopy(const ir::LoweredFunc& x) {
+  ir::Expr copy_func_expr = IRCopy(static_cast<ir::Expr>(x));
+  ir::_LoweredFunc_* copy_func_ptr = copy_func_expr.As<ir::_LoweredFunc_>();
+  return ir::LoweredFunc(copy_func_ptr);
+}
+
+// TODO(zhhsplendid): make IRCopy of std::vector a template function
+std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x) {
+  std::vector<ir::LoweredFunc> res;
+  for (const auto& i : x) {
+    res.emplace_back(IRCopy(i));
+  }
+  return res;
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_copy.h b/paddle/cinn/optim/ir_copy.h
new file mode 100644
index 0000000000000..38baef7067f11
--- /dev/null
+++ b/paddle/cinn/optim/ir_copy.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/lowered_func.h"
+
+namespace cinn {
+
+namespace ir {
+class ModuleExpr;
+} // namespace ir
+
+namespace optim {
+
+//! Deep-copy an expression.
+Expr IRCopy(Expr x);
+
+std::vector<Expr> IRCopy(const std::vector<Expr>& x);
+
+ir::ModuleExpr IRCopy(const ir::ModuleExpr& x);
+
+ir::LoweredFunc IRCopy(const ir::LoweredFunc& x);
+
+std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_copy_test.cc b/paddle/cinn/optim/ir_copy_test.cc
new file mode 100644
index 0000000000000..ee592fda58aed
--- /dev/null
+++ b/paddle/cinn/optim/ir_copy_test.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/ir_copy.h"
+
+#include <gtest/gtest.h>
+
+#include "cinn/ir/ir_printer.h"
+
+namespace cinn {
+namespace optim {
+
+TEST(IrCopy, basic) {
+  Expr a(1.f);
+  auto aa = IRCopy(a);
+  LOG(INFO) << "aa " << aa;
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_replace.cc b/paddle/cinn/optim/ir_replace.cc
new file mode 100755
index 0000000000000..9ebf0c7271680
--- /dev/null
+++ b/paddle/cinn/optim/ir_replace.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/ir_replace.h"
+
+#include <set>
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace optim {
+using utils::GetStreamCnt;
+
+namespace {
+
+struct IrReplaceMutator : ir::IRMutator<Expr*> {
+  std::set<ir::IrNodeTy> valid_nodetys{{ir::IrNodeTy::Broadcast, ir::IrNodeTy::_Var_}};
+
+  IrReplaceMutator(ir::Expr from, Expr to) : from_(from), to_(to), from_repr_(GetStreamCnt(from)) {
+    CHECK(valid_nodetys.count(from->node_type())) << "Not a valid node type, got " << from->node_type();
+  }
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::_Var_* op, Expr* expr) override {
+    if (op->node_type() == from_->node_type() && from_repr_ == GetStreamCnt(*expr)) {
+      *expr = optim::IRCopy(to_);
+    }
+  }
+
+  void Visit(const ir::Broadcast* op, Expr* expr) override {
+    if (op->node_type() == from_->node_type() && from_repr_ == GetStreamCnt(*expr)) {
+      *expr = optim::IRCopy(to_);
+    }
+  }
+
+  std::string from_repr_;
+  ir::Expr from_;
+  Expr to_;
+};
+
+} // namespace
+
+void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to) {
+  CHECK(expr);
+  IrReplaceMutator(from, to)(expr);
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_replace.h b/paddle/cinn/optim/ir_replace.h
new file mode 100644
index 0000000000000..c6982056693e4
--- /dev/null
+++ b/paddle/cinn/optim/ir_replace.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+//! Replace the expression \p from with the expression \p to in the expression \p expr. Only Broadcast and _Var_
+//! nodes are supported as \p from.
+void IrReplace(ir::Expr* expr, ir::Expr from, ir::Expr to);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_simplify.cc b/paddle/cinn/optim/ir_simplify.cc
new file mode 100644
index 0000000000000..0ed3d92c93aeb
--- /dev/null
+++ b/paddle/cinn/optim/ir_simplify.cc
@@ -0,0 +1,365 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/ir_simplify.h"
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include <ginac/ginac.h>
+#include <glog/logging.h>
+
+#include "cinn/common/arithmatic.h"
+#include "cinn/common/cas.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/optim/cast_simplify.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace optim {
+using namespace ir;  // NOLINT
+using common::ExprToGinacConverter;
+using utils::GetStreamCnt;
+using utils::Replace;
+
+namespace {
+
+//! Simplify some sub-expressions in the `expr`. Since the simplify strategy only fits several kinds of IR nodes, we
+//! partition the original expression into several sub-expressions supported by the simplifier, and process each of
+//! them.
+void PartialSimplify(Expr* expr, const absl::flat_hash_map<std::string, common::CasInterval>& var_intervals = {}) {
+  *expr = common::AutoSimplify(*expr, var_intervals);
+}
+
+//! Simplify the expression, except for the Store and Load nodes themselves.
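+//
+// For instance (a rough sketch): inside a loop with 0 <= i < 20, an expression like
+//   (i * 0) + (2 * a) + (3 * a) + 1 + 2
+// folds to
+//   (5 * a) + 3
+// while the Load/Store nodes themselves are visited but left in place.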
+struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
+  common::cas_intervals_t& var_intervals;
+  explicit SimplifyButStoreLoadMutator(common::cas_intervals_t& var_intervals) : var_intervals(var_intervals) {}
+
+  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
+
+  using ir::IRMutator<>::Visit;
+
+#define __(op__) \
+  void Visit(const op__* op, Expr* expr) override { PartialSimplify(expr, var_intervals); }
+
+  __(Add)
+  __(Mul)
+  __(Sub)
+  __(Div)
+  __(Min)
+  __(Max)
+#undef __
+
+  void Visit(const Ramp* op, Expr* expr) override {
+    auto* node = expr->As<ir::Ramp>();
+    CHECK(common::IsPureMath(node->base));
+    CHECK(common::IsPureMath(node->stride));
+    PartialSimplify(&node->base, var_intervals);
+    PartialSimplify(&node->stride, var_intervals);
+  }
+
+  void Visit(const Cast* op, Expr* expr) override {
+    auto* node = expr->As<Cast>();
+    Visit(&node->v(), &node->v());
+  }
+
+  void Visit(const PolyFor* op, Expr* expr) override {
+    auto* node = expr->As<PolyFor>();
+    node->condition = common::SolveInequality(op->condition, op->iterator);
+
+    Visit(&node->body, &node->body);
+  }
+
+  void Visit(const For* op, Expr* expr) override {
+    auto* node = expr->As<For>();
+    Visit(&node->min, &node->min);
+    Visit(&node->extent, &node->extent);
+    auto* min_i = op->min.As<IntImm>();
+    auto* extent_i = op->extent.As<IntImm>();
+    if (min_i && extent_i && extent_i->value > min_i->value) {
+      var_intervals.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
+    } else {
+      var_intervals.emplace(op->loop_var->name, common::CasInterval{op->min, op->extent - 1});
+    }
+
+    Visit(&node->body, &node->body);
+    if (min_i && extent_i) {
+      var_intervals.erase(op->loop_var->name);
+    }
+  }
+
+  void Visit(const _Tensor_* op, Expr* expr) override {
+    auto* node = expr->As<ir::_Tensor_>();
+
+    for (auto& e : node->shape) {
+      PartialSimplify(&e, var_intervals);
+    }
+    for (auto& e : node->domain) {
+      PartialSimplify(&e, var_intervals);
+    }
+  }
+};
+
+struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
+  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
+
+  void Visit(const Load* expr, Expr* op) override {
+    auto* node = op->As<Load>();
+    for (auto& idx : node->indices) {
+      if (common::IsPureMath(idx)) {
+        PartialSimplify(&idx, var_intervals_);
+      } else {
+        SimplifyButStoreLoadMutator mutator(var_intervals_);
+        mutator(&idx);
+      }
+    }
+  }
+
+  void Visit(const For* op, Expr* expr) override {
+    auto* min_i = op->min.As<IntImm>();
+    auto* extent_i = op->extent.As<IntImm>();
+    if (min_i && extent_i && extent_i->value > min_i->value) {
+      var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
+    }
+
+    auto* node = expr->As<For>();
+
+    operator()(&node->body);
+    operator()(&node->extent);
+
+    if (min_i && extent_i) {
+      var_intervals_.erase(op->loop_var->name);
+    }
+  }
+
+  common::cas_intervals_t var_intervals_;
+};
+
+struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
+  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
+
+  void Visit(const Store* expr, Expr* op) override {
+    auto* node = op->As<Store>();
+
+    for (auto& idx : node->indices) {
+      if (common::IsPureMath(idx)) {
+        PartialSimplify(&idx, var_intervals_);
+      } else {
+        SimplifyButStoreLoadMutator mutator(var_intervals_);
+        mutator(&idx);
+      }
+    }
+  }
+
+  void Visit(const For* op, Expr* expr) override {
+    auto* min_i = op->min.As<IntImm>();
+    auto* extent_i = op->extent.As<IntImm>();
+    if (min_i && extent_i) {
+      var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
+    }
+
+    auto* node = expr->As<For>();
+
+    operator()(&node->body);
+    operator()(&node->extent);
+
+    if (min_i && extent_i) {
+      var_intervals_.erase(op->loop_var->name);
+    }
+  }
+
+  common::cas_intervals_t var_intervals_;
+};
+
+struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
+  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
+
+  void Visit(const Ramp* op, Expr* expr) override {
+    auto* node = expr->As<ir::Ramp>();
+
+    CHECK(common::IsPureMath(node->base)) << node->base << " is not pure math!";
+    CHECK(common::IsPureMath(node->stride)) << node->stride << " is not pure math!";
+    Simplify(&node->base);
+    Simplify(&node->stride);
+  }
+  // ramp + ramp
+  void Visit(const Add* op, Expr* expr) override {
+    auto* node = expr->As<ir::Add>();
+    Expr a = node->a();
+    Expr b = node->b();
+    auto a_ramp = a.As<ir::Ramp>();
+    auto b_ramp = b.As<ir::Ramp>();
+
+    if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) {
+      Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base);
+      Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride);
+      *expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes);
+    }
+  }
+};
+
+struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  using ir::IRMutator<>::Visit;
+
+  void Visit(const IfThenElse* op, Expr* expr) override {
+    auto* node = expr->As<ir::IfThenElse>();
+    node->condition = common::AutoSimplify(node->condition);
+
+    if (node->true_case.defined()) Visit(&node->true_case, &node->true_case);
+    if (node->false_case.defined()) Visit(&node->false_case, &node->false_case);
+  }
+};
+
+struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  void Visit(const FracOp* op, Expr* expr) override {
+    auto* node = expr->As<ir::FracOp>();
+
+    ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0));
+    ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1));
+
+    *expr = ir::Div::Make(node->operand(0), node->operand(1));
+  }
+};
+
+struct SimplifyBlocksMutator : public ir::IRMutator<> {
+  SimplifyBlocksMutator() {}
+
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  using ir::IRMutator<>::Visit;
+
+  void Visit(const Block* op, Expr* expr) override {
+    auto* node = expr->As<ir::Block>();
+
+    if (node->stmts.size() == 1 && node->stmts[0].As<ir::Block>()) {
+      VLOG(6) << "Simplify size-1 ir::Block";
+      *expr = node->stmts[0];
+      Visit(expr, expr);
+    } else {
+      for (auto& s : node->stmts) {
+        Visit(&s, &s);
+      }
+      std::vector<Expr> stmts;
+      for (auto& s : node->stmts) {
+        if (s.As<ir::Block>()) {
+          VLOG(6) << "Simplify ir::Block inside ir::Block";
+          auto inner_block = s.As<ir::Block>();
+          for (auto inner_stmt : inner_block->stmts) {
+            stmts.push_back(inner_stmt);
+          }
+        } else {
+          stmts.push_back(s);
+        }
+      }
+      expr->As<ir::Block>()->stmts = stmts;
+    }
+  }
+
+  void Visit(const IfThenElse* op, Expr* expr) override {
+    if (op->condition.As<ir::UIntImm>()) {
+      if (op->condition.as_bool() == false) {
+        VLOG(6) << "Simplify ir::IfThenElse false block";
+        if (expr->As<ir::IfThenElse>()->false_case.defined()) {
+          *expr = expr->As<ir::IfThenElse>()->false_case;
+        } else {
+          *expr = ir::Block::Make({});
+        }
+      } else {
+        if (expr->As<ir::IfThenElse>()->true_case.defined()) {
+          VLOG(6) << "Simplify ir::IfThenElse true block";
+          *expr = expr->As<ir::IfThenElse>()->true_case;
+        } else {
+          *expr = ir::Block::Make({});
+        }
+      }
+      ir::IRMutator<>::Visit(expr, expr);
+      return;
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+};
+
+struct SimplifyForLoopsMutator : public ir::IRMutator<> {
+  absl::flat_hash_map<std::string, common::CasInterval> var_intervals;
+  SimplifyForLoopsMutator() {}
+
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  using ir::IRMutator<>::Visit;
+
+  void Visit(const For* op, Expr* expr) override {
+    auto* node = expr->As<For>();
+    Visit(&node->min, &node->min);
+    Visit(&node->extent, &node->extent);
+    auto* min_i = node->min.As<IntImm>();
+    auto* extent_i = node->extent.As<IntImm>();
+    if (min_i && extent_i && extent_i->value > min_i->value && extent_i->value - min_i->value == 1) {
+      VLOG(6) << "Simplify current For Loop";
+      std::string var_name = node->loop_var->name;
+      var_intervals.emplace(var_name, common::CasInterval{min_i->value, extent_i->value - 1});
+      if (node->body.As<Block>() && node->body.As<Block>()->stmts.size() == 1) {
+        *expr = node->body.As<Block>()->stmts[0];
+      } else {
+        *expr = node->body;
+      }
+      Visit(expr, expr);
+      var_intervals.erase(var_name);
+    } else {
+      Visit(&node->body, &node->body);
+    }
+  }
+
+  void Visit(const _Var_* op, Expr* expr) override {
+    auto* node = expr->As<ir::_Var_>();
+
+    if (var_intervals.count(node->name)) {
+      auto loop_range = var_intervals.at(node->name);
+      *expr = Expr(loop_range.l);
+    }
+  }
+};
+
+} // namespace
+
+void Simplify(Expr* expr) {
+  VLOG(3) << "Begin Simplify " << *expr;
+  optim::CastSimplify(expr);
+  SimplifyRampMutator()(expr);
+  SimplifyLoadMutator()(expr);
+  SimplifyStoreMutator()(expr);
+  SimplifyIfThenElseMutator()(expr);
+
+  common::cas_intervals_t var_intervals;
+  SimplifyButStoreLoadMutator mutator(var_intervals);
+  mutator(expr);
+
+  ReplaceFracWithDivMutator()(expr);
+}
+
+void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
+void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_simplify.h b/paddle/cinn/optim/ir_simplify.h
new file mode 100644
index 0000000000000..f5e2bdf82f6ba
--- /dev/null
+++ b/paddle/cinn/optim/ir_simplify.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * Simplify the expression.
+ * The following cases are supported:
+ * a + 0 => a
+ * a * 0 => 0
+ * A[i*0+2*a+3*a+1+2] => A[5*a+3]
+ *
+ * This only works on simple IR nodes such as Load and Store, and on the math operators such as Add and Sub.
+ */
+void Simplify(Expr *expr);
+
+void SimplifyForLoops(Expr *expr);
+
+void SimplifyBlocks(Expr *expr);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/ir_simplify_test.cc b/paddle/cinn/optim/ir_simplify_test.cc
new file mode 100755
index 0000000000000..b9e7fb807a072
--- /dev/null
+++ b/paddle/cinn/optim/ir_simplify_test.cc
@@ -0,0 +1,127 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/ir_simplify.h"
+
+#include <gtest/gtest.h>
+
+#include "cinn/cinn.h"
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+using utils::GetStreamCnt;
+using utils::Trim;
+
+TEST(IrSimplify, basic) {
+  auto A = Compute(
+      {Expr(100), Expr(20)}, [&](Var i, Var j) { return Expr(1.f); }, "C");
+  Buffer A_buf(A->type());
+  A->Bind(A_buf);
+
+  Var i("i"), j("j");
+  i->set_type(Int(32));
+  j->set_type(Int(32));
+
+  {  // simple case
+    auto B = A(i, Expr(0)) + 1.f * 0.f + 100.f + 24.5f;
+
+    LOG(INFO) << "B " << B;
+    // get (((C[(i * 20)] + 0) + 100) + 24.5)
+    Simplify(&B);
+    LOG(INFO) << "simplified: " << B;
+    auto out = "(124.500000f + C[i, 0])";
+    EXPECT_EQ(out, utils::GetStreamCnt(B));
+  }
+
+  {
+    Placeholder<float> x("X", {100, 20});
+    Placeholder<float> y("y", {100, 20});
+
+    auto B = Compute(
+        {Expr(100), Expr(20)},
+        [&](Expr i, Expr j) {
+          return x(i + 0, j + 0) + y(i, j * 0) * 1.f + 0.f * x(i, j) + 25.f + 100.f - 0.f +
+                 9.f * 10000.f * 1.f * 1.f * 0.f;
+        },
+        "B");
+
+    auto stages = CreateStages({B});
+    auto func = Lower("func", stages, {B});
+    auto body = func->body;
+
+    LOG(INFO) << "original body:\n" << body;
+    Simplify(&body);
+    auto target_out = R"ROC(
+{
+  serial for (i, 0, 100)
+  {
+    serial for (j, 0, 20)
+    {
+      B[i, j] = (125.000000f + (X[i, j] + y[i, 0]))
+    }
+  }
+}
+)ROC";
+    EXPECT_EQ(Trim(target_out), Trim(GetStreamCnt(body)));
+  }
+
+  {
+    Placeholder<float> x("X", {100, 20});
+    Placeholder<float> y("y", {100, 20});
+
+    auto B = Compute(
+        {Expr(100), Expr(20)},
+        [&](Expr i, Expr j) {
+          return x(100 * 10 * 1 * i + 0, j * 0) + y(i, j * 0) / (1.f + 2.f) + 0.f * x(i, j) + 25.f + 100.f - 0.f +
+                 9.f * 10000.f * 1.f * 1.f * 0.f;
+        },
+        "B");
+
+    auto stages = CreateStages({B});
+
+    auto func = Lower("func", stages, {B});
+    auto body = func->body;
+
+    LOG(INFO) << "original body:\n" << body;
+    Simplify(&body);
+
+    auto target_out = R"ROC(
+{
+  serial for (i, 0, 100)
+  {
+    serial for (j, 0, 20)
+    {
+      B[i, j] = ((y[i, 0] / 3.00000000f) + (125.000000f + X[(1000 * i), 0]))
+    }
+  }
+}
+)ROC";
+    EXPECT_EQ(Trim(target_out), Trim(GetStreamCnt(body)));
+  }
+}
+
+TEST(reverse, prod) {
+  Expr M(100), N(20);
+  Placeholder<float> A("A", {M, N});
+  auto C = Compute(
+      {M, N}, [=](Var i, Var j) { return Expr(1.f) / A(i, j); }, "C");
+
+  auto stages = CreateStages({A, C});
+  auto fn = Lower("fn", stages, {A, C});
+  LOG(INFO) << "fn:\n" << fn;
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/lower_function_call_bind_vars.cc b/paddle/cinn/optim/lower_function_call_bind_vars.cc
new file mode 100644
index 0000000000000..abb80fc56b871
--- /dev/null
+++ b/paddle/cinn/optim/lower_function_call_bind_vars.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/lower_function_call_bind_vars.h"
+
+#include <string>
+#include <vector>
+
+#include "cinn/ir/ir_mutator.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+struct LowerFunctionCallBindVarsMutator : public ir::IRMutator<> {
+  LowerFunctionCallBindVarsMutator() = default;
+
+  void operator()(Expr* m) {
+    m_ = m->as_module();
+    Expr module(m->get());
+    ir::IRMutator<>::Visit(&module, &module);
+  }
+
+ private:
+  void Visit(const ir::Call* op, Expr* expr) {
+    auto* node = expr->As<ir::Call>();
+    if (op->is_cinn_call()) {
+      const std::string& target = op->name;
+      auto it = std::find_if(m_->functions.begin(), m_->functions.end(), [&](const Expr& x) {
+        return x.as_lowered_func()->name == target;
+      });
+      CHECK(it != m_->functions.end()) << "The called function [" << target << "] does not exist";
+
+      std::vector<ir::Var> extra_var_args;
+
+      for (auto& arg : (*it).as_lowered_func()->args) {
+        if (arg.is_var()) {
+          extra_var_args.push_back(arg.var_arg());
+        }
+      }
+
+      // Insert the extra var arguments at the beginning of the original call's argument list.
+      node->read_args.insert(std::begin(op->read_args), extra_var_args.begin(), extra_var_args.end());
+    }
+
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+ private:
+  ir::_Module_* m_{};
+};
+
+} // namespace
+
+void LowerFunctionCallBindVars(Expr* m) {
+  CHECK(m->as_module());
+  LowerFunctionCallBindVarsMutator()(m);
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/lower_function_call_bind_vars.h b/paddle/cinn/optim/lower_function_call_bind_vars.h
new file mode 100644
index 0000000000000..d5b941862a9c7
--- /dev/null
+++ b/paddle/cinn/optim/lower_function_call_bind_vars.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/module.h"
+
+namespace cinn {
+namespace optim {
+
+void LowerFunctionCallBindVars(Expr *m);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/lower_intrin.cc b/paddle/cinn/optim/lower_intrin.cc
new file mode 100644
index 0000000000000..e342af8fbeb22
--- /dev/null
+++ b/paddle/cinn/optim/lower_intrin.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/lower_intrin.h" + +#include + +#include "cinn/backends/llvm/llvm_intrin_rule.h" +#include "cinn/cinn.h" +#include "cinn/ir/intrinsic_ops.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/registry.h" + +namespace cinn { +namespace optim { + +void LowerIntrin(Expr *e, Target target) { + if (target.arch == Target::Arch::X86) { + codegen::RegisterCpuIntrinRule(); + } else { + return; + } + struct Mutator : ir::IRMutator { + Target target; + + explicit Mutator(Target target) : target(target) {} + + void operator()(Expr *e) { ir::IRMutator<>::Visit(e, e); } + + void Visit(const ir::Add *op, Expr *expr) override { + auto *node = expr->As(); + CHECK(node); + Expr ret; + if (node->type().is_float()) { + if (const ir::Mul *mul = node->b().As()) { + ret = ir::Call::Make(node->type(), "fma", {mul->a(), mul->b(), node->a()}, {}, ir::CallType::Intrinsic); + } else if (const ir::Mul *mul = node->a().As()) { + ret = ir::Call::Make(node->type(), "fma", {mul->a(), mul->b(), node->b()}, {}, ir::CallType::Intrinsic); + } + if (ret.defined()) { + ir::IRMutator<>::Visit(&ret, &ret); + *expr = ret; + return; + } + } + ir::IRMutator<>::Visit(&node->a(), &node->a()); + ir::IRMutator<>::Visit(&node->b(), &node->b()); + } + + void Visit(const ir::Call *op, Expr *expr) override { + auto *node = expr->As(); + CHECK(node); + LowerCpuintrinsicOp(node, expr); + } + + void LowerCpuintrinsicOp(ir::Call *op, Expr *expr) { + auto *node = expr->As(); + if (kIntrinsicCalls.count(node->name)) { + CHECK(!node->name.empty()); + auto *func_ptr = ir::Registry::Get("lower_cpu_intrinsic_" + node->name); + CHECK(func_ptr) << "find no rule to lower cpu intrinsic for " + << "lower_cpu_intrinsic_" + node->name; + Expr ret = (*func_ptr)(Expr(node)); + if (!ret.same_as(*expr)) { + ir::IRMutator<>::Visit(&ret, &ret); + } + *expr = ret; + return; + } + for (auto &expr : node->read_args) { + ir::IRMutator<>::Visit(&expr, &expr); + } + for (auto &expr : node->write_args) { + ir::IRMutator<>::Visit(&expr, &expr); + } + } + }; + + Mutator m(target); + m(e); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/lower_intrin.h b/paddle/cinn/optim/lower_intrin.h new file mode 100644 index 0000000000000..1b4b5cd2ac42d --- /dev/null +++ b/paddle/cinn/optim/lower_intrin.h @@ -0,0 +1,41 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
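+
+// A minimal sketch of the rewrite LowerIntrin performs for a float add on
+// X86; the names below are hypothetical and only illustrate the effect:
+//
+//   Expr e = a + b * c;  // node->b() matches ir::Mul, operands are float
+//   optim::LowerIntrin(&e, common::DefaultHostTarget());
+//   // e is now the intrinsic call fma(b, c, a), which is later lowered
+//   // through the registered "lower_cpu_intrinsic_fma" rule.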
+ +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +static const std::set kIntrinsicCalls{ + {"exp", "exp2", "sqrt", "log", "log2", "log10", "floor", + "ceil", "round", "trunc", "cos", "cosh", "tan", "tanh", + "sin", "sinh", "fabs", "isnan", "isfinite", "isinf", "left_shift", + "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not", "fma", "rsqrt"}}; + +/** + * Map the Call nodes to llvm intrinsic. + * + * This will rename the external call with the function in different backends. + * + * Notes: only support cpu currently. + */ +void LowerIntrin(Expr *e, Target target); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/map_extern_call.cc b/paddle/cinn/optim/map_extern_call.cc new file mode 100644 index 0000000000000..f67129a64567b --- /dev/null +++ b/paddle/cinn/optim/map_extern_call.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/map_extern_call.h" + +#include "cinn/cinn.h" +#include "cinn/hlir/op/op_util.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/runtime/cpu/host_intrinsics.h" + +namespace cinn { +namespace optim { + +static const std::set kExternFp32CallsGPU{ + {"exp", "erf", "sigmoid", "sqrt", "log", "log2", "log10", "floor", "ceil", "round", "trunc", + "cos", "cosh", "tan", "sin", "sinh", "acos", "acosh", "asin", "asinh", "atan", "atanh", + "isnan", "tanh", "isfinite", "isinf", "remainder", "rsqrt", "cbrt", "abs", "pow", "mod"}}; + +static const std::set kExternInt32CallsGPU{{"left_shift", + "right_shift", + "bitwise_or", + "bitwise_and", + "bitwise_xor", + "bitwise_not", + "pow", + "logical_right_shift", + "clz", + "popc", + "mod"}}; + +static const std::set kExternFp32CallsCPU = { + "erf", "acos", "acosh", "asin", "asinh", "atan", "atanh", "remainder"}; + +void MapExternCall(Expr *e, Target target) { + struct Mutator : ir::IRMutator { + Target target; + + explicit Mutator(Target target) : target(target) {} + + void operator()(Expr *e) { ir::IRMutator<>::Visit(e, e); } + + void Visit(const ir::Call *op, Expr *expr) override { + auto *node = expr->As(); + CHECK(node); + OptimizeConstantPow(node); + if (target.arch == Target::Arch::NVGPU) { + DealWithNvGpuintrinsics(node, expr); + } else { + DealWithCpuintrinsics(node, expr); + } + } + + void DealWithCpuintrinsics(ir::Call *node, Expr *expr) { + if (kExternFp32CallsCPU.count(node->name)) { + CHECK_GE(node->read_args.size(), 1UL); + CHECK(node->read_args.front().type().is_float()) + << "CPU extern call instrinsices only support float now! 
Please check."; + if (node->read_args.front().type().is_float(32)) { + auto out_type = node->type(); + *expr = lang::CallExtern(node->name + "f", node->read_args); + } + } + } + + void DealWithNvGpuintrinsics(ir::Call *node, Expr *expr) { + auto arg_size = node->read_args.size(); + if (arg_size == 0UL) { + // some node like __syncthreads hasn't arguments + return; + } + const auto &dtype = node->read_args.front().type(); + const auto &name = node->name; + + bool node_in_extern_fp32 = kExternFp32CallsGPU.count(name); + bool node_in_extern_int32 = kExternInt32CallsGPU.count(name); + if (!node_in_extern_fp32 && !node_in_extern_int32) { + return; + } + + std::string extern_func = hlir::GetExternFuncName(common::DefaultNVGPUTarget(), dtype, name); + *expr = lang::CallExtern(extern_func, node->read_args, node->attrs); + } + + // Replace pow(x, 0.5) to sqrt(x) and pow(x, -0.5) to rsqrt(x), which + // can speed up a lot. + // + // Reference: + // https://en.wikipedia.org/wiki/Fast_inverse_square_root + void OptimizeConstantPow(ir::Call *node) { + if (node->name == "pow" && node->read_args.size() >= 2 && node->read_args[1].is_constant()) { + float pow_constant = node->read_args[1].get_constant(); + if (pow_constant == 0.5) { + node->name = "sqrt"; + node->read_args.erase(node->read_args.begin() + 1); + } else if (pow_constant == -0.5) { + node->name = "rsqrt"; + node->read_args.erase(node->read_args.begin() + 1); + } + } + } + }; + + Mutator m(target); + m(e); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/map_extern_call.h b/paddle/cinn/optim/map_extern_call.h new file mode 100644 index 0000000000000..6ece28f96bad6 --- /dev/null +++ b/paddle/cinn/optim/map_extern_call.h @@ -0,0 +1,33 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Map the Call nodes to external function call. + * + * This will rename the external call with the function in different backends. + */ +void MapExternCall(Expr *e, Target target); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc new file mode 100644 index 0000000000000..55ddc705700de --- /dev/null +++ b/paddle/cinn/optim/optimize.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn/optim/optimize.h" + +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/ir_schedule_util.h" +#include "cinn/optim/call_arg_list_to_pod_value.h" +#include "cinn/optim/cast_bool_to_int8.h" +#include "cinn/optim/cast_simplify.h" +#include "cinn/optim/eliminate_broadcast_in_forloop.h" +#include "cinn/optim/extern_call_process.h" +#include "cinn/optim/fold_cinn_call_arguments.h" +#include "cinn/optim/if_simplify.h" +#include "cinn/optim/insert_debug_log_callee.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/lower_function_call_bind_vars.h" +#include "cinn/optim/lower_intrin.h" +#include "cinn/optim/map_extern_call.h" +#include "cinn/optim/remove_nested_block.h" +#include "cinn/optim/remove_schedule_block.h" +#include "cinn/optim/replace_const_param_to_integer.h" +#include "cinn/optim/transform_gpu_forloop.h" +#include "cinn/optim/transform_polyfor_to_for.h" +#include "cinn/optim/unroll_loops.h" +#include "cinn/optim/vectorize_loops.h" + +DECLARE_bool(cinn_ir_schedule); + +namespace cinn { +namespace optim { + +Expr Optimize(Expr e, Target target, bool runtime_debug_info, bool remove_gpu_for_loops) { + CHECK(e.defined()); + auto copied = IRCopy(e); + + FoldCINNCallArguments(&copied); + TransformPolyForToFor(&copied); + ReplaceConstParamToInteger(&copied); + // Simplify already contains CastSimplify + Simplify(&copied); + UnrollLoop(&copied); + VLOG(4) << "After Optimize UnrollLoop:" << copied; + + VectorizeLoops(&copied, target); + VLOG(4) << "After Optimize VectorizeLoops:" << copied; +#ifdef CINN_WITH_CUDA + if (FLAGS_cinn_ir_schedule && copied.as_lowered_func()) { + ir::SetCudaAxisInfo(&copied); + } + if (remove_gpu_for_loops) { + RemoveGpuForloopsAxis(&copied); + } + CudaSyncThreadsDropIfThenElse(&copied); +#endif + + RemoveNestedBlock(&copied); + VLOG(4) << "After Optimize RemoveNestedBlock:" << copied; + + MapExternCall(&copied, target); + VLOG(10) << "After Optimize MapExternCall:" << copied; + + ExternCallMultiOutputShallowStore(&copied); + VLOG(10) << "After Optimize ExternCallMultiOutputShallowStore:" << copied; + // Simplify already contains CastSimplify + Simplify(&copied); + VLOG(10) << "After Optimize Simplify:" << copied; + + IfSimplify(&copied); + VLOG(10) << "After Optimize IfSimplify:" << copied; + + if (runtime_debug_info) { + LOG(WARNING) << "Turn on runtime debug information output"; + InsertDebugLogCallee(&copied); + } + return copied; +} + +ir::Module Optimize(const ir::Module& module, const Target& target) { + auto copied = IRCopy(Expr(module)); + if (FLAGS_cinn_ir_schedule) { + UnrollLoop(&copied); + VectorizeLoops(&copied, Target()); + } + VLOG(10) << "After VectorizeLoops:" << copied.as_module_ref(); + RemoveScheduleBlock(&copied); + VLOG(10) << "After RemoveScheduleBlock:" << copied.as_module_ref(); + LowerFunctionCallBindVars(&copied); + VLOG(10) << "After LowerFunctionCallBindVars:" << copied.as_module_ref(); + CallArgListToPodValue(&copied); + VLOG(10) << "After CallArgListToPodValue:" << copied.as_module_ref(); + LowerIntrin(&copied, target); + VLOG(10) << "After LowerIntrin:" << copied.as_module_ref(); + + return copied.as_module_ref(); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/optimize.h b/paddle/cinn/optim/optimize.h new file mode 100644 index 0000000000000..7d1165f3d883c --- /dev/null +++ b/paddle/cinn/optim/optimize.h @@ -0,0 +1,36 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir.h" +#include "cinn/ir/module.h" + +namespace cinn { +namespace optim { + +/** + * Optimize the expression but Module. + * @param e + * @param runtime_debug_info + * @return + */ +Expr Optimize(Expr e, Target target, bool runtime_debug_info = false, bool remove_gpu_for_loops = true); + +/** + * Optimize a Module. + */ +ir::Module Optimize(const ir::Module& module, const Target& target); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/optimize_test.cc b/paddle/cinn/optim/optimize_test.cc new file mode 100755 index 0000000000000..1479fa6b37871 --- /dev/null +++ b/paddle/cinn/optim/optimize_test.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/optimize.h" + +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace optim { + +TEST(Optimize, Unroll) { + Placeholder A("A", {100, 20}); + + auto C = Compute( + {Expr(100), Expr(20)}, [&](Var i, Var j) { return A(i, j) + 1.f; }, "C"); + auto stages = CreateStages({C}); + + stages[C]->Split(1, 5); + stages[C]->Unroll(2); + + auto func = Lower("matmul", stages, {A, C}); + + auto out = R"ROC( +{ + serial for (i, 0, 100) + { + serial for (j_outer, 0, 4) + { + C[i, (5 * j_outer)] = (1.00000000f + A[i, (5 * j_outer)]) + C[i, (1 + (5 * j_outer))] = (1.00000000f + A[i, (1 + (5 * j_outer))]) + C[i, (2 + (5 * j_outer))] = (1.00000000f + A[i, (2 + (5 * j_outer))]) + C[i, (3 + (5 * j_outer))] = (1.00000000f + A[i, (3 + (5 * j_outer))]) + C[i, (4 + (5 * j_outer))] = (1.00000000f + A[i, (4 + (5 * j_outer))]) + } + } +} +)ROC"; + + EXPECT_EQ(utils::Trim(out), utils::Trim(utils::GetStreamCnt(func->body))); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/remove_nested_block.cc b/paddle/cinn/optim/remove_nested_block.cc new file mode 100644 index 0000000000000..366dd23a1a33f --- /dev/null +++ b/paddle/cinn/optim/remove_nested_block.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/remove_nested_block.h"
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+
+namespace cinn {
+namespace optim {
+
+Expr GetExprInsideBlock(Expr op) {
+  Expr node = op;
+  while (node.As<ir::Block>()) {
+    auto& stmts = node.As<ir::Block>()->stmts;
+    if (stmts.size() == 1) {
+      node = stmts.front();
+    } else {
+      break;
+    }
+  }
+  return node;
+}
+
+// This will remove the nested blocks, but it will also remove the block outside the forloop's body.
+struct NestedBlockSimplifer : public ir::IRMutator<ir::Expr *> {
+  void operator()(ir::Expr* expr) { Visit(expr); }
+
+ private:
+  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::Block* expr, Expr* op) override {
+    auto* node = op->As<ir::Block>();
+    if (node->stmts.size() == 1) {
+      *op = GetExprInsideBlock(*op);
+      IRMutator::Visit(op, op);
+    } else {
+      IRMutator::Visit(expr, op);
+    }
+  }
+};
+
+struct NestedBlockRemover : public ir::IRMutator<ir::Expr *> {
+  void operator()(ir::Expr* expr) { Visit(expr); }
+
+ private:
+  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::Block* expr, Expr* op) override {
+    auto* node = op->As<ir::Block>();
+
+    std::vector<ir::Expr> new_exprs;
+
+    bool detect_nested = false;
+    for (auto it = node->stmts.begin(); it != node->stmts.end(); it++) {
+      auto* block = it->As<ir::Block>();
+      if (block) {
+        detect_nested = true;
+        new_exprs.insert(std::end(new_exprs), block->stmts.begin(), block->stmts.end());
+      } else {
+        new_exprs.push_back(*it);
+      }
+    }
+
+    node->stmts = new_exprs;
+
+    IRMutator::Visit(expr, op);
+  }
+};
+
+// add block outside forloop's body.
+struct AddBlockToForloop : public ir::IRMutator<> {
+  void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::For* expr, Expr* op) override {
+    auto* node = op->As<ir::For>();
+    if (!node->body.As<ir::Block>()) {
+      node->body = ir::Block::Make({node->body});
+    }
+
+    ir::IRMutator<>::Visit(expr, op);
+  }
+
+  void Visit(const ir::PolyFor* expr, Expr* op) override {
+    auto* node = op->As<ir::PolyFor>();
+    if (!node->body.As<ir::Block>()) {
+      node->body = ir::Block::Make({node->body});
+    }
+
+    ir::IRMutator<>::Visit(expr, op);
+  }
+
+  void Visit(const ir::_LoweredFunc_* expr, Expr* op) override {
+    auto* node = op->As<ir::_LoweredFunc_>();
+    if (!node->body.As<ir::Block>()) {
+      node->body = ir::Block::Make({node->body});
+    }
+
+    ir::IRMutator<>::Visit(expr, op);
+  }
+};
+
+void RemoveNestedBlock(Expr* e) {
+  NestedBlockRemover()(e);
+  NestedBlockSimplifer()(e);
+  AddBlockToForloop()(e);
+}
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/remove_nested_block.h b/paddle/cinn/optim/remove_nested_block.h
new file mode 100644
index 0000000000000..cf6393fc863a1
--- /dev/null
+++ b/paddle/cinn/optim/remove_nested_block.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * This file implements the strategy to remove the unnecessary nested block. + */ +#pragma once +#include + +#include "cinn/common/common.h" +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Remove the unecessary nested block. + */ +void RemoveNestedBlock(Expr* e); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/remove_nested_block_test.cc b/paddle/cinn/optim/remove_nested_block_test.cc new file mode 100644 index 0000000000000..a62689c7d1ea0 --- /dev/null +++ b/paddle/cinn/optim/remove_nested_block_test.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/remove_nested_block.h" + +#include + +#include +#include + +#include "cinn/ir/ir_printer.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace optim { + +TEST(RemoveNestedBlock, basic) { + auto block0 = ir::Block::Make({Expr(1.f), Expr(1.f)}); + auto block1 = ir::Block::Make({block0}); + auto e = Expr(block1); + + std::string origin = utils::GetStreamCnt(e); + EXPECT_EQ(origin, utils::Trim(R"ROC( +{ + { + 1.00000000f + 1.00000000f + } +} + )ROC")); + + std::cout << "origin:\n" << e << std::endl; + + RemoveNestedBlock(&e); + + std::cout << "e:\n" << e << std::endl; + + EXPECT_EQ(utils::GetStreamCnt(e), utils::Trim(R"ROC( +{ + 1.00000000f + 1.00000000f +} + )ROC")); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/remove_schedule_block.cc b/paddle/cinn/optim/remove_schedule_block.cc new file mode 100644 index 0000000000000..e496ccdca4f0f --- /dev/null +++ b/paddle/cinn/optim/remove_schedule_block.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
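+
+// Schematic effect of RemoveScheduleBlock (pseudo-IR, for illustration):
+//
+//   ScheduleBlock(C) { i0, i1 = axis.bind(i, j); C[i0, i1] = A[i0, i1] }
+//   ==>
+//   C[i, j] = A[i, j]
+//
+// Each iter_var is substituted with its bound iter_value, and the realize
+// node collapses into its body.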
+
+#include "cinn/optim/remove_schedule_block.h"
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/replace_var_with_expr.h"
+
+namespace cinn {
+namespace optim {
+
+struct ScheduleBlockRemover : public ir::IRMutator<ir::Expr *> {
+  void operator()(ir::Expr* expr) { Visit(expr); }
+
+ private:
+  void Visit(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override {
+    auto* node = expr->As<ir::ScheduleBlockRealize>();
+    CHECK(node);
+    auto& iter_values = node->iter_values;
+    auto* schedule_block = node->schedule_block.As<ir::ScheduleBlock>();
+    CHECK(schedule_block);
+    auto& iter_vars = schedule_block->iter_vars;
+    Expr body = schedule_block->body;
+    CHECK_EQ(iter_vars.size(), iter_values.size());
+    for (int i = 0; i < iter_vars.size(); i++) {
+      optim::ReplaceVarWithExpr(&body, iter_vars[i], iter_values[i]);
+    }
+    *expr = body;
+    IRMutator::Visit(expr, expr);
+  }
+};
+
+void RemoveScheduleBlock(Expr* e) { ScheduleBlockRemover()(e); }
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/remove_schedule_block.h b/paddle/cinn/optim/remove_schedule_block.h
new file mode 100644
index 0000000000000..791c12159f81f
--- /dev/null
+++ b/paddle/cinn/optim/remove_schedule_block.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ * This file implements the strategy to remove the schedule blocks after scheduling.
+ */
+#pragma once
+#include <string>
+
+#include "cinn/common/common.h"
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * Remove the schedule blocks, substituting each bound iter_var with its iter_value.
+ */
+void RemoveScheduleBlock(Expr* e);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/remove_schedule_block_test.cc b/paddle/cinn/optim/remove_schedule_block_test.cc
new file mode 100755
index 0000000000000..bf41b729ea900
--- /dev/null
+++ b/paddle/cinn/optim/remove_schedule_block_test.cc
@@ -0,0 +1,98 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
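+
+// The test below lowers C[i, j] = sum(A[i, k0] * B[k0, j]) with schedule
+// blocks enabled, then checks that RemoveScheduleBlock substitutes the
+// bound axes (i0 -> i, i1 -> j, i2 -> k0) and drops the block wrappers.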
+ +#include "cinn/optim/remove_schedule_block.h" + +#include + +#include +#include + +#include "cinn/cinn.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace optim { + +TEST(RemovescheduleBlock, basic) { + using namespace ir; // NOLINT + Context::Global().ResetNameId(); + Placeholder A("A", {Expr(100), Expr(20)}); + Placeholder B("B", {Expr(20), Expr(50)}); + Target target = common::DefaultHostTarget(); + Module::Builder builder("matmul", target); + // C = A * B + Var k(20, "k0"); + Tensor C = Compute( + {Expr(100), Expr(50)}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + auto stages = CreateStages({A, B, C}); + auto func = Lower("matmul", stages, {A, B, C}, {}, {}, nullptr, target, true); + LOG(INFO) << "func\n" << func; + + std::string origin = utils::GetStreamCnt(func); + EXPECT_EQ(origin, utils::Trim(R"ROC( +function matmul (_A, _B, _C) +{ + ScheduleBlock(root) + { + serial for (i, 0, 100) + { + serial for (j, 0, 50) + { + ScheduleBlock(C__reduce_init) + { + i0, i1 = axis.bind(i, j) + C__reduce_init[i0, i1] = 0.00000000f + } + serial for (k0, 0, 20) + { + ScheduleBlock(C) + { + i0_0, i1_0, i2 = axis.bind(i, j, k0) + C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0])) + } + } + } + } + } +} +)ROC")); + + RemoveScheduleBlock(&func->body); + + std::cout << "after RemovescheduleBlock:\n" << func << std::endl; + + EXPECT_EQ(utils::GetStreamCnt(func), utils::Trim(R"ROC( +function matmul (_A, _B, _C) +{ + serial for (i, 0, 100) + { + serial for (j, 0, 50) + { + C__reduce_init[i, j] = 0.00000000f + serial for (k0, 0, 20) + { + C[i, j] = (C[i, j] + (A[i, k0] * B[k0, j])) + } + } + } +} +)ROC")); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_call_with_expr.cc b/paddle/cinn/optim/replace_call_with_expr.cc new file mode 100644 index 0000000000000..ac69e484cec31 --- /dev/null +++ b/paddle/cinn/optim/replace_call_with_expr.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
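+
+// Usage sketch (hypothetical statement name and candidate expression):
+//
+//   Expr e = ...;  // contains Call nodes named "S0"
+//   optim::ReplaceCallWithExpr(&e, "S0", C(i, j) + 1.f);
+//   // every Call named "S0" in e is now inlined as C(i, j) + 1.f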
+
+#include "cinn/optim/replace_call_with_expr.h"
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/replace_var_with_expr.h"
+
+namespace cinn {
+namespace optim {
+
+struct ReplaceCallWithExprModifier : public ir::IRMutator<> {
+  ReplaceCallWithExprModifier(const std::string &statement, const Expr &candidate)
+      : statement_(statement), candidate_(candidate) {}
+
+  void operator()(Expr *e) { IRMutator<>::Visit(e, e); }
+
+ private:
+  void Visit(const ir::Call *expr, Expr *op) override {
+    auto *node = op->As<ir::Call>();
+    CHECK(!node->name.empty()) << "Call has no name";
+    VLOG(3) << "Processing Call node " << *op;
+    if (statement_ != node->name) return;
+
+    Expr expr_candidate = IRCopy(candidate_);
+    VLOG(3) << "Original candidate expr: " << candidate_;
+    VLOG(3) << "Copied candidate expr: " << expr_candidate;
+
+    // Replace the Call node with the expression candidate.
+    *op = expr_candidate;
+    VLOG(3) << "expr " << *op;
+  }
+
+ private:
+  std::string statement_;
+  const Expr &candidate_;
+};
+
+void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate) {
+  ReplaceCallWithExprModifier modifier(statement, candidate);
+  modifier(e);
+}
+
+void ReplaceIslCallWithExpr(Expr *e,
+                            const std::string &statement,
+                            const Expr &candidate,
+                            const std::map<std::string, Expr> &axis_map) {
+  VLOG(3) << "ReplaceCallWithExpr, original expression: " << candidate;
+  Expr copied = IRCopy(candidate);
+  // Update the axis in the copied expression.
+
+  // We treat the Store node as a normal statement; the others, such as Call nodes, have no axis.
+  std::map<std::string, Expr> local_axis;
+  std::vector<std::string> origin_axes;
+  std::map<std::string, Expr> new_axis_map = axis_map;
+  for (auto &item : axis_map) {
+    origin_axes.push_back(item.first);
+  }
+  // Append '_after' to a transformed var's name to avoid transforming it twice.
+  // For example, given indices [i,j], if we want to swap 'i' and 'j' (i->j, j->i),
+  // without the '_after' suffix the processing would be:
+  //   1. [i,j] to [j,j]
+  //   2. [j,j] to [i,i]
+  // and we would get [i,i], which differs from the correct result [j,i].
+  // With the '_after' suffix the processing is:
+  //   1. [i,j] to [j_after,j]
+  //   2. [j_after,j] to [j_after,i_after]
+  //   3. [j_after,i_after] to [j,i]
+  // Mission complete!
+  for (auto &item : new_axis_map) {
+    for (auto &axis : origin_axes) {
+      ReplaceVarWithExpr(&item.second, Var(axis), Expr(Var(axis + "_after")));
+    }
+  }
+  if (copied.As<ir::Store>()) {
+    auto *store = copied.As<ir::Store>();
+    for (int i = 0; i < store->indices.size(); i++) {
+      auto indice = store->indices[i];
+      if (indice.is_var() || indice.is_constant()) {
+        if (!new_axis_map.count(std::to_string(i))) continue;
+        if (!indice.is_constant()) {
+          local_axis[indice.as_var()->name] = new_axis_map.at(std::to_string(i));
+        }
+      }
+    }
+    // The store indices only contain the axes of the transform's domain, not the range.
+    // e.g. for { s[i,j] -> s[i0,i1,j]: i0=i/4 and i1=i%4 }, the store's indices only contain i,j, while in the
+    // final code the axes come from the range; that is, some new axes such as i0 and i1 do not exist in
+    // store->indices.
+ } + + for (auto &laxis : local_axis) { + VLOG(3) << "local_axis Replacing axis: " << laxis.first << " to " << laxis.second; + ReplaceVarWithExpr(&copied, Var(laxis.first), laxis.second); + } + // replace the remaining axis(in the transform's range) + for (auto &item : new_axis_map) { + if (!local_axis.count(item.first)) { + VLOG(3) << "new_axis_map Replacing axis: " << item.first << " to " << item.second; + ReplaceVarWithExpr(&copied, Var(item.first), item.second); + } + } + + for (auto &axis : origin_axes) { + ReplaceVarWithExpr(&copied, Var(axis + "_after"), Expr(Var(axis))); + } + + VLOG(3) << "After replacing, the statement [" << statement << "] is : " << copied; + ReplaceCallWithExpr(e, statement, copied); +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_call_with_expr.h b/paddle/cinn/optim/replace_call_with_expr.h new file mode 100644 index 0000000000000..470a4835038e8 --- /dev/null +++ b/paddle/cinn/optim/replace_call_with_expr.h @@ -0,0 +1,45 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Replace a Call node with a Expr (inline). + * @param e The expression to modify. + * @param statement The map from tuple_name to the expression candidate. + * @param candidate Var of each axis in the expression candidate. + */ +void ReplaceCallWithExpr(Expr *e, const std::string &statement, const Expr &candidate); + +/** + * Replace a Call node with a Expr (inline). + * @param e The expression to modify. + * @param statement The map from tuple_name to the expression candidate. + * @param candidate Var of each axis in the expression candidate. + * @param axis_map The map from a variable to expression. + */ +void ReplaceIslCallWithExpr(Expr *e, + const std::string &statement, + const Expr &candidate, + const std::map &axis_map); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_call_with_expr_test.cc b/paddle/cinn/optim/replace_call_with_expr_test.cc new file mode 100644 index 0000000000000..f5d08027a89d4 --- /dev/null +++ b/paddle/cinn/optim/replace_call_with_expr_test.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
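+
+// Illustrative trace of the '_after' renaming used by ReplaceIslCallWithExpr,
+// assuming an axis_map that swaps i and j:
+//
+//   [i, j] -> [j_after, j] -> [j_after, i_after] -> [j, i]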
+ +#include "cinn/optim/replace_call_with_expr.h" + +#include + +#include "cinn/ir/buffer.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/placeholder.h" +#include "cinn/poly/ast_gen.h" + +namespace cinn { +namespace optim { + +using namespace poly; + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_const_param_to_integer.cc b/paddle/cinn/optim/replace_const_param_to_integer.cc new file mode 100644 index 0000000000000..9d270e4e8d9b6 --- /dev/null +++ b/paddle/cinn/optim/replace_const_param_to_integer.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/replace_const_param_to_integer.h" + +#include "cinn/ir/ir_mutator.h" +#include "cinn/poly/ast_gen.h" +#include "cinn/utils/string.h" + +namespace cinn::optim { + +namespace { + +struct Mutator : public ir::IRMutator<> { + using ir::IRMutator<>::Visit; + + void Visit(const ir::_Var_* op, Expr* expr) override { + if (utils::Startswith(op->name, poly::kIslParamConstPrefix)) { + std::string value = op->name.substr(strlen(poly::kIslParamConstPrefix)); + *expr = Expr(std::stoi(value)); + } + } +}; + +} // namespace + +void ReplaceConstParamToInteger(Expr* e) { + Mutator mutator; + mutator.Visit(e, e); +} + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/replace_const_param_to_integer.h b/paddle/cinn/optim/replace_const_param_to_integer.h new file mode 100644 index 0000000000000..40b7dee5b3299 --- /dev/null +++ b/paddle/cinn/optim/replace_const_param_to_integer.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir.h" + +namespace cinn::optim { + +/** + * Replace the constant parameter(included in ISL param) to the corresponding integer. + * + * e.g. + * + * The expression: + * for (int i = 0; i <= _const_0; i++) ... + * + * to + * + * for (int i = 0; i < 0; i++) + */ +void ReplaceConstParamToInteger(Expr* e); + +} // namespace cinn::optim diff --git a/paddle/cinn/optim/replace_var_with_expr.cc b/paddle/cinn/optim/replace_var_with_expr.cc new file mode 100644 index 0000000000000..c10a16bb60339 --- /dev/null +++ b/paddle/cinn/optim/replace_var_with_expr.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/optim/replace_var_with_expr.h" + +#include "cinn/common/cas.h" +#include "cinn/ir/ir.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/ir/tensor.h" +#include "cinn/optim/ir_copy.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/optim/replace_const_param_to_integer.h" + +namespace cinn { +namespace optim { + +struct ReplaceVarWithExprMutator : public ir::IRMutator<> { + ReplaceVarWithExprMutator(const Var& var, const Expr& expr, const std::string& tensor_name) + : var_(var), expr_(expr), tensor_name_(tensor_name) {} + + void operator()(Expr* expr) { + if (tensor_name_.empty()) visit_all_ = true; + IRMutator::Visit(expr, expr); + } + + private: + void Visit(const ir::_Var_* expr, Expr* op) override { + if (expr->name == var_->name && (do_replace_ || visit_all_)) { + auto copied = IRCopy(expr_); + *op = copied; + } + } + + void Visit(const ir::For* op, Expr* expr) override { + auto* node = expr->As(); + ir::IRMutator<>::Visit(&node->min, &node->min); + ir::IRMutator<>::Visit(&node->extent, &node->extent); + ir::IRMutator<>::Visit(&node->body, &node->body); + if (node->loop_var->name == var_->name && expr_.As() && visit_all_) { + node->loop_var = expr_.As(); + } + } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + ir::IRMutator<>::Visit(&node->init, &node->init); + ir::IRMutator<>::Visit(&node->condition, &node->condition); + ir::IRMutator<>::Visit(&node->inc, &node->inc); + ir::IRMutator<>::Visit(&node->body, &node->body); + if (node->iterator->name == var_->name && expr_.As() && visit_all_) { + node->iterator = expr_.As(); + } + } + + void Visit(const ir::Store* op, Expr* expr) override { + auto* node = expr->As(); + auto* tensor = node->tensor.as_tensor(); + + if (tensor->name == tensor_name_) { + do_replace_ = true; + } else { + do_replace_ = false; + } + for (auto& index : node->indices) { + ir::IRMutator<>::Visit(&index, &index); + } + do_replace_ = false; + ir::IRMutator<>::Visit(&node->tensor, &node->tensor); + ir::IRMutator<>::Visit(&node->value, &node->value); + } + + void Visit(const ir::Load* expr, Expr* op) override { + auto* node = op->As(); + auto* tensor = node->tensor.as_tensor(); + if (tensor->name == tensor_name_) { + do_replace_ = true; + } else { + do_replace_ = false; + } + for (auto& idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx); + do_replace_ = false; + ir::IRMutator<>::Visit(&node->tensor, &node->tensor); + } + + private: + bool do_replace_{false}; + bool visit_all_{false}; + const Var& var_; + const Expr& expr_; + const std::string& tensor_name_; +}; + +void ReplaceVarWithExpr(Expr* source, const Var& var, const Expr& expr, const std::string& tensor_name) { + ReplaceVarWithExprMutator mutator(var, expr, tensor_name); + mutator(source); +} + +struct CollectTensorIndexMutator : public ir::IRMutator<> { + CollectTensorIndexMutator(const std::string& tensor_name) : 
tensor_name_(tensor_name) {} + + std::vector> operator()(Expr* expr) { + IRMutator::Visit(expr, expr); + return res; + } + + private: + void Visit(const ir::For* op, Expr* expr) override { + auto* node = expr->As(); + ir::IRMutator<>::Visit(&node->body, &node->body); + } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + ir::IRMutator<>::Visit(&node->body, &node->body); + } + + void Visit(const ir::Load* expr, Expr* op) override { + auto* node = op->As(); + auto* tensor = node->tensor.as_tensor(); + if (tensor->name == tensor_name_) { + ir::IRMutator<>::Visit(&node->tensor, &node->tensor); + res.push_back(node->indices); + } else { + ir::IRMutator<>::Visit(&node->tensor, &node->tensor); + for (auto& idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx); + } + } + + private: + std::vector> res; + const std::string& tensor_name_; +}; + +std::vector> CollectTensorIndex(Expr* source, const std::string& tensor_name) { + CollectTensorIndexMutator mutator(tensor_name); + std::vector> result = mutator(source); + for (auto& i : result) { + for (auto& j : i) { + j = common::AutoSimplify(j); + } + } + return result; +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/replace_var_with_expr.h b/paddle/cinn/optim/replace_var_with_expr.h new file mode 100644 index 0000000000000..50b2b2dd3ce31 --- /dev/null +++ b/paddle/cinn/optim/replace_var_with_expr.h @@ -0,0 +1,77 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Replace the variable with a expression. + * @param var The variable to replace. + * @param expr The candidate expression. + * @param tensor_name Name of the tensor whose indices will be edited. If it is empty, means we will + * do the replace in all Expr instead of only in specific tensor's indices. + */ +/** + * Example 1: ReplaceVarWithExpr(source, Var("i"), Expr(0), "A") + * for(i, 0, 10) + * for(j, 0, 10) + * B[i,j] = A[i,j] + * + * => + * + * for(i, 0, 10) + * for(j, 0, 10) + * B[i,j] = A[0,j] + * + * Example 2: ReplaceVarWithExpr(source, Var("i"), Expr(Var("k"))) + * for(i, 0, 10) + * for(j, 0, 10) + * B[i,j] = A[i,j] + * + * => + * + * for(k, 0, 10) + * for(j, 0, 10) + * B[k,j] = A[k,j] + */ +void ReplaceVarWithExpr(Expr *source, const Var &var, const Expr &expr, const std::string &tensor_name = ""); + +/** + * Collect the specific tensor's indices. + * @param tensor_name The specific tensor's name. + * @return Return a vector containing all the indices of the specific tensor appeared in source. 
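+ * Each collected index is simplified with common::AutoSimplify before it is returned.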
+ */
+/**
+ * Example: CollectTensorIndex(source, "A")
+ * for(i, 0, 10)
+ *   for(j, 0, 10)
+ *     C[i,j] = A[i,j] + A[0,j] + B[j,i] + B[i,0]
+ *
+ * =>
+ *
+ * Return value:
+ * {{i,j},{0,j}}
+ */
+std::vector<std::vector<Expr>> CollectTensorIndex(Expr *source, const std::string &tensor_name);
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/tensor_write_tell.cc b/paddle/cinn/optim/tensor_write_tell.cc
new file mode 100644
index 0000000000000..d52590cf17d29
--- /dev/null
+++ b/paddle/cinn/optim/tensor_write_tell.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/tensor_write_tell.h"
+
+namespace cinn {
+namespace optim {} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/tensor_write_tell.h b/paddle/cinn/optim/tensor_write_tell.h
new file mode 100644
index 0000000000000..a44664ca25baf
--- /dev/null
+++ b/paddle/cinn/optim/tensor_write_tell.h
@@ -0,0 +1,54 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <set>
+#include <string>
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_mutator.h"
+
+namespace cinn {
+namespace optim {
+
+struct TensorWriteTeller : public ir::IRMutator<const Expr *> {
+  //! Collect the write info in \p op.
+  void Collect(const Expr* op) { Visit(op, op); }
+
+  bool IsWrite(const std::string& tensor_name) const { return tensor_written.count(tensor_name); }
+
+ private:
+  std::set<std::string> tensor_written;
+
+  void Visit(const Expr* expr, const Expr* op) override { IRMutator::Visit(expr, op); }
+
+  void Visit(const ir::Store* expr, const Expr* op) override {
+    auto* node = op->As<ir::Store>();
+    CHECK(node);
+    auto* tensor = node->tensor.As<ir::_Tensor_>();
+    CHECK(tensor);
+    tensor_written.insert(tensor->name);
+    IRMutator::Visit(expr, op);
+  }
+
+  void Visit(const ir::_Tensor_* op, const Expr* expr) override {
+    auto* node = expr->As<ir::_Tensor_>();
+    if (node->is_call_node()) {
+      tensor_written.insert(node->name);
+    }
+  }
+};
+
+} // namespace optim
+} // namespace cinn
diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc
new file mode 100644
index 0000000000000..86e2572a7a70c
--- /dev/null
+++ b/paddle/cinn/optim/transform_gpu_forloop.cc
@@ -0,0 +1,664 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/transform_gpu_forloop.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include "cinn/backends/cuda_util.h"
+#include "cinn/common/cas.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/ir_simplify.h"
+#include "cinn/optim/replace_var_with_expr.h"
+#include "cinn/poly/isl_utils.h"
+#include "cinn/poly/stage.h"
+#include "cinn/runtime/intrinsic.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * 1. Determine the grid and block dimensions.
+ * It accepts domains like `[0, 20]` or `[0, min(20, M/2)]`; the domain must have an integer right bound.
+ *
+ * 2. Replace the grid/thread iterators with something like `threadIdx.x`, `threadIdx.y`.
+ *
+ * 3. Remove the forloops that own the gpu axis.
+ *   1. if the extent is an IntImm, just remove the forloop.
+ *   2. if the extent is a Min, replace the forloop with an IfThenElse that keeps the forloop's condition; an extra
+ *      lower-bound check is added if the forloop's min is not zero.
+ *
+ * @param expr The expression to mutate.
+ */
+void RemoveGpuForloopsAxis(Expr *expr) {
+  struct Mutator : public ir::IRMutator<> {
+    void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+   private:
+    void Visit(const ir::For *op, Expr *expr) override {
+      switch (op->for_type()) {
+        case ir::ForType::GPUBlock:
+          if (NeedToReplaceForloopWithIfThenElse(op)) {
+            ReplaceForloopWithIfThenElse(expr);
+          } else {
+            *expr = op->body;
+          }
+          IRMutator<>::Visit(expr, expr);
+          break;
+        case ir::ForType::GPUThread:
+          if (NeedToReplaceForloopWithIfThenElse(op)) {
+            ReplaceForloopWithIfThenElse(expr);
+          } else {
+            *expr = op->body;
+          }
+          IRMutator<>::Visit(expr, expr);
+          break;
+        default:
+          auto *node = expr->As<ir::For>();
+          IRMutator<>::Visit(&node->body, &node->body);
+          break;
+      }
+    }
+
+    bool NeedToReplaceForloopWithIfThenElse(const ir::For *n) const { return true; }
+
+    void ReplaceForloopWithIfThenElse(Expr *expr) {
+      auto *for_n = expr->As<ir::For>();
+      auto *poly_for_n = expr->As<ir::PolyFor>();
+      CHECK(for_n || poly_for_n);
+
+      Expr condition;
+
+      auto condition_append = [&](Expr new_cond) {
+        if (condition.defined()) {
+          condition = ir::And::Make(condition, new_cond);
+        } else {
+          condition = new_cond;
+        }
+      };
+
+      if (for_n) {
+        // for(i, 2, 100);
+        //       ^
+        if (for_n->min != common::make_const(0)) {
+          condition_append(ir::GE::Make(for_n->loop_var, for_n->min));
+        }
+
+        // for(i, 2, min(M/2, 20))
+        //              ^
+        condition_append(ir::LT::Make(for_n->loop_var, for_n->extent));
+      } else {
+        if (poly_for_n->init != common::make_const(0)) {
+          condition_append(ir::GE::Make(poly_for_n->iterator, poly_for_n->init));
+        }
+
+        condition_append(poly_for_n->condition);
+      }
+
+      CHECK(condition.defined());
+
+      VLOG(3) << "GPU replacing\n" << *expr;
+      VLOG(3) << "\nto\n";
+      auto if_n = ir::IfThenElse::Make(condition, for_n->body);
+      VLOG(3) << if_n;
+      *expr = if_n;
+    }
+
+    void Visit(const ir::PolyFor *op, Expr *expr) override {
+      // PolyFor nodes are expected to have been lowered to For nodes (e.g. by TransformPolyForToFor) before this
+      // pass runs, so a GPU-bound PolyFor here is a hard error.
+      const auto msg = "PolyFor is not allowed for GPU, only For nodes are allowed";
+      CHECK(op->for_type() != ir::ForType::GPUBlock) << msg;
+      CHECK(op->for_type() != ir::ForType::GPUThread) << msg;
+      CHECK(op->for_type() != ir::ForType::GPULane) << msg;
+    }
+  };
+
+  Mutator mutator;
+  mutator(expr);
+}
+
+/**
+ * The generated __syncthreads call will be wrapped in an `if (xxxx == 0) { }`; this is an artifact of the isl AST
+ * output. Drop the condition so that the call runs in all the threads, e.g.
+ * `if (threadIdx.x == 0) { __syncthreads(); }` becomes `__syncthreads();`.
+ */
+void CudaSyncThreadsDropIfThenElse(Expr *expr) {
+  struct Mutator : public ir::IRMutator<> {
+    void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+    void Visit(const ir::IfThenElse *op, Expr *expr) override {
+      blocked_statement_stack.push_back(expr);
+      ir::IRMutator<>::Visit(op, expr);
+      blocked_statement_stack.pop_back();
+    }
+
+    void Visit(const ir::Call *op, Expr *expr) override {
+      if (op->name == runtime::intrinsic::cuda_sync_threads) {
+        if (!blocked_statement_stack.empty()) {
+          auto *last_for = blocked_statement_stack.back()->As<ir::IfThenElse>();
+          if (auto *eq_n = last_for->condition.As<ir::EQ>()) {
+            if (eq_n->b() == common::make_const(0)) {
+              *blocked_statement_stack.back() = *expr;
+            }
+          }
+        }
+      }
+    }
+
+    // Stack of the enclosing IfThenElse statements on the path to the current node.
+    std::vector<ir::Expr *> blocked_statement_stack;
+  };
+
+  Mutator()(expr);
+}
+
+class RestructureVarNodes : public ir::IRMutator<> {
+ public:
+  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Load *load, Expr *op) override {
+    std::vector<ir::Expr> indices_copied;
+    for (const ir::Expr &indice : load->indices) {
+      indices_copied.push_back(IRCopy(indice));
+    }
+    op->As<ir::Load>()->indices = indices_copied;
+
+    IRMutator::Visit(load, op);
+  }
+
+  void Visit(const ir::Store *store, Expr *op) override {
+    std::vector<ir::Expr> indices_copied;
+    for (const ir::Expr &indice : store->indices) {
+      indices_copied.push_back(IRCopy(indice));
+    }
+    op->As<ir::Store>()->indices = indices_copied;
+
+    IRMutator::Visit(store, op);
+  }
+};
+
+class ReplaceIndexToBindExpr : public ir::IRMutator<> {
+ public:
+  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::ScheduleBlockRealize *op, Expr *expr) override {
+    ir::ScheduleBlockRealize *schedule_block_realize = expr->As<ir::ScheduleBlockRealize>();
+    CHECK(schedule_block_realize->schedule_block.As<ir::ScheduleBlock>());
+    std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values;
+    ir::Expr body = schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()->body;
+    std::vector<ir::Var> iter_vars = schedule_block_realize->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
+
+    CHECK_EQ(iter_values.size(), iter_vars.size());
+    for (int idx = 0; idx < iter_values.size(); ++idx) {
+      ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]);
+    }
+    ir::IRMutator<>::Visit(&body, &body);
+  }
+};
+
+using TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>;
+class CollectTensorLoopVisitor : public ir::IRMutator<> {
+ public:
+  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Store *op, Expr *expr) override {
+    auto tensor = op->tensor.as_tensor_ref();
+    // if the buffer is defined and is not on the heap:
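+    // (Buffers on the heap, i.e. global memory, are deliberately skipped here: presumably only shared/local
+    // buffers need their indices rewritten by the pass below.)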
+    if (tensor->buffer.defined() && tensor->buffer->memory_type != ir::MemoryType::Heap) {
+      if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
+        buffer_tensor_loop_map_[tensor->buffer->name].push_back(std::make_pair(*expr, loops_));
+      } else {
+        buffer_tensor_loop_map_[tensor->buffer->name] = {std::make_pair(*expr, loops_)};
+      }
+    }
+
+    IRMutator::Visit(op, expr);
+  }
+
+  void Visit(const ir::Load *op, Expr *expr) override {
+    if (op->is_addr_scalar()) {
+      return;
+    }
+    auto tensor = op->tensor.as_tensor_ref();
+    // if the buffer is defined and is not on the heap:
+    if (tensor->buffer.defined() && tensor->buffer->memory_type != ir::MemoryType::Heap) {
+      if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
+        buffer_tensor_loop_map_[tensor->buffer->name].push_back(std::make_pair(*expr, loops_));
+      } else {
+        buffer_tensor_loop_map_[tensor->buffer->name] = {std::make_pair(*expr, loops_)};
+      }
+    }
+
+    IRMutator::Visit(op, expr);
+  }
+
+  void Visit(const ir::For *op, Expr *expr) override {
+    loops_.push_back(*expr);
+    IRMutator::Visit(op, expr);
+    loops_.pop_back();
+  }
+
+  void Visit(const ir::PolyFor *op, Expr *expr) override { LOG(FATAL) << "Unknown PolyFor!"; }
+
+ public:
+  std::vector<ir::Expr> loops_;
+  std::unordered_map<std::string, std::vector<TENSOR_LOOP>> buffer_tensor_loop_map_;
+};
+
+void UpdateBufferAxisPass(ir::Expr *expr) {
+  CollectTensorLoopVisitor collect_tensor_loop_visitor;
+  collect_tensor_loop_visitor(expr);
+
+  auto buffer_tensor_loop = collect_tensor_loop_visitor.buffer_tensor_loop_map_;
+
+  for (auto &tmp : buffer_tensor_loop) {
+    auto tensor_loop_v = tmp.second;
+
+    auto &front = tensor_loop_v.front();
+    int count = tensor_loop_v.size() > 1 ? front.second.size() : 0;
+    for (int idx = 1; idx < tensor_loop_v.size(); ++idx) {
+      auto &other = tensor_loop_v[idx];
+      for (int idy = 0; idy < std::min(front.second.size(), other.second.size()); ++idy) {
+        if (front.second[idy] != other.second[idy]) {
+          count = std::min(count, idy);
+          break;
+        }
+      }
+    }
+
+    auto get_thread_bind_var = [](const std::vector<ir::Expr> &loops) {
+      // map from a threadIdx axis name to its (loop_var name, extent).
+      using ThreadLoopVarExtentMap = std::unordered_map<std::string, std::pair<std::string, int>>;
+      ThreadLoopVarExtentMap thread_loop_var_extent_map;
+      for (auto loop : loops) {
+        auto loop_ir = loop.As<ir::For>();
+        CHECK(loop_ir);
+        if (loop_ir->is_gpu_thread_binded()) {
+          std::string axis = "";
+          if (loop_ir->bind_info().offset == 0) {
+            axis = "threadIdx.x";
+          } else if (loop_ir->bind_info().offset == 1) {
+            axis = "threadIdx.y";
+          } else {
+            axis = "threadIdx.z";
+          }
+          // insert gpu thread loop var.
+          if (thread_loop_var_extent_map.count(axis)) {
+            auto &loop_var_extent = thread_loop_var_extent_map[axis];
+            if (loop_var_extent.second >= loop_ir->extent.as_int32()) {
+              thread_loop_var_extent_map[axis] = std::make_pair(loop_ir->loop_var->name, loop_ir->extent.as_int32());
+            }
+          } else {
+            thread_loop_var_extent_map[axis] = std::make_pair(loop_ir->loop_var->name, loop_ir->extent.as_int32());
+          }
+        }
+      }
+
+      std::unordered_set<std::string> loop_var_map;
+      for (auto &tmp : thread_loop_var_extent_map) {
+        loop_var_map.insert(tmp.second.first);
+      }
+
+      return loop_var_map;
+    };
+
+    auto load = front.first.As<ir::Load>();
+    auto store = front.first.As<ir::Store>();
+    auto tensor = load ?
load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref(); + // find store and load keep loop for shared + std::vector> keep_loop_vars; + if (tensor->buffer->memory_type == ir::MemoryType::GPUShared) { + for (auto &tensor_loop : tensor_loop_v) { + keep_loop_vars.push_back(get_thread_bind_var(tensor_loop.second)); + } + CHECK_EQ(keep_loop_vars.size(), tensor_loop_v.size()); + } + + auto &loops = front.second; + for (int idx = 0; idx < count; ++idx) { + auto loop_expr = loops[idx]; + auto loop_ir = loop_expr.As(); + auto loop_var = loop_ir->loop_var; + + for (int idy = 0; idy < tensor_loop_v.size(); ++idy) { + auto expr = tensor_loop_v[idy].first; + auto load = expr.As(); + auto store = expr.As(); + if (keep_loop_vars.size() == 0 || !keep_loop_vars[idy].count(loop_var->name)) { + auto &indices = load ? load->indices : store->indices; + for (auto &indice : indices) { + optim::ReplaceVarWithExpr(&indice, loop_var, ir::Expr(0)); + indice = common::AutoSimplify(indice); + } + } + } + } + } +} + +class ReplaceLoopVarToGpu : public ir::IRMutator<> { + public: + void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::For *op, Expr *expr) override { + auto for_ir = expr->As(); + CHECK(for_ir); + + auto bind_info = for_ir->bind_info(); + + std::string var_name = ""; + if (bind_info.offset == 0) + var_name = "x"; + else if (bind_info.offset == 1) + var_name = "y"; + else if (bind_info.offset == 2) + var_name = "z"; + if (for_ir->is_gpu_block_binded()) { + var_name = "blockIdx." + var_name; + optim::ReplaceVarWithExpr(expr, op->loop_var, ir::Expr(ir::Var(var_name))); + } else if (for_ir->is_gpu_thread_binded()) { + var_name = "threadIdx." + var_name; + optim::ReplaceVarWithExpr(expr, op->loop_var, ir::Expr(ir::Var(var_name))); + } + + ir::IRMutator<>::Visit(&for_ir->body, &for_ir->body); + } + void Visit(const ir::PolyFor *op, Expr *expr) override { LOG(FATAL) << "Unkown PolyFor!"; } +}; + +class SharedAxisVisitor : public ir::IRMutator<> { + public: + void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::Store *op, Expr *expr) override { + auto store = expr->As(); + if (!store->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + if (store->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPUShared) { + for (auto &indice : store->indices) { + for (auto axis : gpu_axis) { + optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); + } + indice = common::AutoSimplify(indice); + } + } + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::Load *op, Expr *expr) override { + auto load = expr->As(); + if (load->is_addr_scalar()) { + return; + } + if (!load->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPUShared) { + for (auto &indice : load->indices) { + for (auto axis : gpu_axis) { + optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0)); + } + indice = common::AutoSimplify(indice); + } + } + ir::IRMutator<>::Visit(op, expr); + } + + const std::vector gpu_axis = {"blockIdx.x", "blockIdx.y", "blockIdx.z"}; +}; + +class LocalAxisVisitor : public ir::IRMutator<> { + public: + void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::Store *op, Expr *expr) override { + auto store = expr->As(); + if (!store->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + if (store->tensor.as_tensor_ref()->buffer->memory_type 
== ir::MemoryType::GPULocal) {
+      for (auto &indice : store->indices) {
+        for (auto axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+        }
+        indice = common::AutoSimplify(indice);
+      }
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  void Visit(const ir::Load *op, Expr *expr) override {
+    auto load = expr->As<ir::Load>();
+    if (load->is_addr_scalar()) {
+      return;
+    }
+    if (!load->tensor.as_tensor_ref()->buffer.defined()) {
+      return;
+    }
+
+    if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::GPULocal) {
+      for (auto &indice : load->indices) {
+        for (auto axis : gpu_axis) {
+          optim::ReplaceVarWithExpr(&indice, ir::Var(axis), ir::Expr(0));
+        }
+        indice = common::AutoSimplify(indice);
+      }
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  const std::vector<std::string> gpu_axis = {
+      "blockIdx.x", "blockIdx.y", "blockIdx.z", "threadIdx.x", "threadIdx.y", "threadIdx.z"};
+};
+
+class ResizeBufferSizeVisitor : public ir::IRMutator<> {
+ public:
+  void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(const ir::Store *op, Expr *expr) override {
+    auto store = expr->As<ir::Store>();
+    auto store_tensor = store->tensor.as_tensor_ref();
+
+    if (!store_tensor->buffer.defined()) {
+      return;
+    }
+    if (store_tensor->buffer->memory_type == ir::MemoryType::Heap) {
+      ir::IRMutator<>::Visit(op, expr);
+      return;
+    }
+
+    auto &indices = store->indices;
+    auto &shape = store_tensor->shape;
+    auto &buffer = store_tensor->buffer->shape;
+
+    shape.clear();
+    buffer.clear();
+    for (int idx = 0; idx < indices.size(); ++idx) {
+      shape.push_back(ir::Expr(BufferSize(indices[idx])));
+      buffer.push_back(shape.back());
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  void Visit(const ir::Load *op, Expr *expr) override {
+    auto load = expr->As<ir::Load>();
+    if (!load->tensor.as_tensor_ref()->buffer.defined()) {
+      return;
+    }
+
+    if (load->tensor.as_tensor_ref()->buffer->memory_type == ir::MemoryType::Heap) {
+      ir::IRMutator<>::Visit(op, expr);
+      return;
+    }
+
+    load->tensor.as_tensor_ref()->shape = load->tensor.as_tensor_ref()->buffer->shape;
+
+    // For the moment, align the load tensor's indices with the tensor shape using this trick; a better way would be
+    // to modify the FlattenLoop schedule.
+    int cnt = load->indices.size() - load->tensor.as_tensor_ref()->shape.size();
+    for (int i = 0; i < cnt; i++) {
+      load->indices.erase(load->indices.begin());
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  void Visit(const ir::For *op, Expr *expr) override {
+    CHECK(expr->As<ir::For>());
+    auto for_ir = expr->As<ir::For>();
+    auto var_name = for_ir->loop_var->name;
+    auto extent_i = for_ir->extent;
+
+    if (extent_i.is_constant()) loop_2_extent_[var_name] = extent_i.as_int32();
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  int BufferSize(ir::Expr indice) {
+    auto copy = IRCopy(indice);
+    auto vars = ir::CollectIRNodesInOrder(copy, [](const ir::Expr *expr) { return expr->As<ir::_Var_>(); });
+
+    int max_range = 1;
+    // compute the maximal index value by recursively enumerating each loop var over its extent.
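+    // For example (illustrative): for an index expression `4 * i + j` with recorded extents i < 2 and j < 4,
+    // the recursion below substitutes every (i, j) pair, finds the maximal value 7, and BufferSize returns 8.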
+ std::function compute_range = [&](const int deep, ir::Expr index) { + auto var = vars[deep].as_var_ref(); + CHECK(loop_2_extent_.count(var->name)) << var->name; + auto extent = loop_2_extent_.find(var->name)->second; + + for (int idx = 0; idx < extent; ++idx) { + auto tmp = IRCopy(index); + ReplaceVarWithExpr(&tmp, var, Expr(idx)); + + if (deep == vars.size() - 1) { + auto simplify = common::AutoSimplify(tmp); + auto range = common::AutoSimplify(simplify); + CHECK(range.is_constant()); + max_range = std::max(max_range, range.as_int32() + 1); + } else { + compute_range(deep + 1, tmp); + } + } + }; + + if (vars.size()) compute_range(0, copy); + return max_range; + } + + std::unordered_map loop_2_extent_; +}; + +class ReplaceVarToZero : public ir::IRMutator<> { + public: + void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::Store *op, Expr *expr) override { + auto store = expr->As(); + if (!store->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + auto &indices = store->indices; + for (auto &indice : indices) { + for (auto var_ : loop_var_) { + optim::ReplaceVarWithExpr(&indice, ir::Var(var_), ir::Expr(0)); + } + indice = common::AutoSimplify(indice); + } + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::Load *op, Expr *expr) override { + auto load = expr->As(); + if (!load->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + auto &indices = load->indices; + for (auto &indice : indices) { + for (auto var_ : loop_var_) { + optim::ReplaceVarWithExpr(&indice, ir::Var(var_), ir::Expr(0)); + } + indice = common::AutoSimplify(indice); + } + + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::For *op, Expr *expr) override { + CHECK(expr->As()); + auto for_ir = expr->As(); + auto var_name = for_ir->loop_var->name; + auto extent_i = for_ir->extent; + + if (extent_i.is_constant() && extent_i.as_int32() == 1) loop_var_.insert(var_name); + ir::IRMutator<>::Visit(op, expr); + loop_var_.erase(var_name); + } + std::unordered_set loop_var_; +}; + +void OptimizeExprGPU(Expr *expr) { + VLOG(2) << "Before Optimize Expr:\n" << *expr; + + // copy var nodes to prevent one modification leading to multiple changes + RestructureVarNodes restructure_var_nodes; + restructure_var_nodes(expr); + + // replace var to bind expr + ReplaceIndexToBindExpr replace_index_to_bind_expr; + replace_index_to_bind_expr(expr); + + // resize buffer axis + UpdateBufferAxisPass(expr); + + // replace var name with block/thread + ReplaceLoopVarToGpu replace_loop_var_to_gpu; + replace_loop_var_to_gpu(expr); + + // update shared buffer axis + SharedAxisVisitor shared_axis_visitor; + shared_axis_visitor(expr); + + // update local buffer axis + LocalAxisVisitor local_axis_visitor; + local_axis_visitor(expr); + + ResizeBufferSizeVisitor resize_buffer_size_visitor; + resize_buffer_size_visitor(expr); + + ReplaceVarToZero replace_var_to_zero; + replace_var_to_zero(expr); + + VLOG(2) << "After Optimize Expr: \n" << *expr; +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/transform_gpu_forloop.h b/paddle/cinn/optim/transform_gpu_forloop.h new file mode 100644 index 0000000000000..bffe8f412c8a7 --- /dev/null +++ b/paddle/cinn/optim/transform_gpu_forloop.h @@ -0,0 +1,65 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+#include
+
+#include "cinn/ir/ir.h"
+#include "cinn/ir/lowered_func.h"
+#include "cinn/poly/isl_utils.h"
+#include "cinn/poly/stage.h"
+
+namespace cinn {
+namespace optim {
+
+/**
+ * Optimize an Expr for GPU:
+ *   - replace GPU-bound 'for' loops with 'blockIdx'/'threadIdx' axes,
+ *   - update buffer indices to save memory,
+ *   - re-compute buffer sizes.
+ */
+void OptimizeExprGPU(Expr* expr);
+
+/**
+ * Remove the forloops of the block and thread axes, and add the kernel launch thread dimension information to the
+ * outermost LoweredFunc.
+ *
+ * For example, input the code:
+ * \code
+ * // Note here, the outermost expression should be a LoweredFunc
+ * _LoweredFunc_:
+ *   for (blockIdx.x, 0, 10)
+ *     for (threadIdx.x, 0, 20)
+ *       A(blockIdx.x, threadIdx.x)
+ * \endcode
+ *
+ * will be modified to
+ * \code
+ * _LoweredFunc_:
+ *   A(blockIdx.x, threadIdx.x)
+ * \endcode
+ *
+ * \note The dimensions of each threadIdx or blockIdx must be constant, so this pass only accepts For nodes; PolyFor
+ * \note nodes are not allowed to be GPU related.
+ */
+void RemoveGpuForloopsAxis(Expr* expr);
+
+/**
+ * Add __syncthreads() to shared memory producers.
+ */
+void CudaSyncThreadsDropIfThenElse(Expr* expr);
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/transform_polyfor_to_for.cc b/paddle/cinn/optim/transform_polyfor_to_for.cc
new file mode 100644
index 0000000000000..3913056fbf719
--- /dev/null
+++ b/paddle/cinn/optim/transform_polyfor_to_for.cc
@@ -0,0 +1,136 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
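+// Illustrative sketch of this pass's effect (not taken from the original sources): a PolyFor with a simple
+// upper-bound condition, roughly
+//   poly_for (i, init = 0, condition = (i <= 31), inc = 1) body
+// is rewritten into the equivalent For node
+//   for (i, 0, 32) body
+// i.e. the right bound of an LE condition is incremented by one (see PlusOneWithMinMax below) to form the
+// exclusive loop extent.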
+
+#include "cinn/optim/transform_polyfor_to_for.h"
+
+#include
+#include
+
+#include "cinn/common/arithmatic.h"
+#include "cinn/common/cas.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/common/type.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/ir/ir_visitor.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/ir_simplify.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+Expr PlusOneWithMinMax(Expr expr) {
+  auto* min_n = expr.As<ir::Min>();
+  auto* max_n = expr.As<ir::Max>();
+
+  if (min_n) {
+    min_n->a() = min_n->a() + 1;
+    min_n->b() = min_n->b() + 1;
+    Simplify(&min_n->a());
+    Simplify(&min_n->b());
+    return expr;
+  } else if (max_n) {
+    max_n->a() = max_n->a() + 1;
+    max_n->b() = max_n->b() + 1;
+    Simplify(&max_n->a());
+    Simplify(&max_n->b());
+    return expr;
+  }
+  return expr + 1;
+}
+
+struct PolyForWithSimpleConditionToForMutator : public ir::IRMutator<> {
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+  void Visit(const ir::PolyFor* op, Expr* expr) override {
+    auto* node = expr->As<ir::PolyFor>();
+    auto* ge_n = node->condition.As<ir::GE>();
+    auto* gt_n = node->condition.As<ir::GT>();
+    if (ge_n) {
+      node->condition = (ge_n->a() * -1) <= (ge_n->b() * -1);
+    }
+    if (gt_n) {
+      node->condition = (gt_n->a() * -1) < (gt_n->b() * -1);
+    }
+
+    auto* lt_n = node->condition.As<ir::LT>();
+    auto* le_n = node->condition.As<ir::LE>();
+
+    if (lt_n) {
+      if (lt_n->b() != common::make_const(0)) {
+        node->condition = lt_n->a() - lt_n->b() < 0;
+      }
+    }
+    if (le_n) {
+      if (le_n->b() != common::make_const(0)) {
+        node->condition = le_n->a() - le_n->b() <= 0;
+      }
+    }
+
+    lt_n = node->condition.As<ir::LT>();
+    le_n = node->condition.As<ir::LE>();
+    if (!(lt_n || le_n)) return;
+
+    // check that the lhs is the iterator
+    bool can_extract_extent = (lt_n && lt_n->a().as_var() && lt_n->a().as_var()->name == op->iterator->name) ||
+                              (le_n && le_n->a().as_var() && le_n->a().as_var()->name == op->iterator->name);
+
+    if (!can_extract_extent) {
+      if (node->condition.As<ir::LE>()) {
+        auto le = node->condition.As<ir::LE>();
+        CHECK(le->a().As<ir::Sub>());
+        CHECK_EQ(le->b().As<ir::IntImm>()->value, 0UL);
+        auto sub = le->a().As<ir::Sub>();
+        node->condition = ir::LE::Make(sub->a(), sub->b());
+      } else if (node->condition.As<ir::LT>()) {
+        auto lt = node->condition.As<ir::LT>();
+        CHECK(lt->a().As<ir::Sub>());
+        CHECK_EQ(lt->b().As<ir::IntImm>()->value, 0UL);
+        auto sub = lt->a().As<ir::Sub>();
+        node->condition = ir::LT::Make(sub->a(), sub->b());
+      } else {
+        LOG(FATAL) << "Unknown Type!";
+      }
+
+      lt_n = node->condition.As<ir::LT>();
+      le_n = node->condition.As<ir::LE>();
+      if (!(lt_n || le_n)) return;
+    }
+
+    Expr lhs = lt_n ? lt_n->a() : le_n->a();
+    Expr rhs = lt_n ? lt_n->b() : PlusOneWithMinMax(le_n->b());
+    rhs = common::AutoSimplify(rhs);
+
+    if (op->is_vectorized()) CHECK(op->vectorize_info().valid());
+
+    Expr new_for =
+        ir::For::Make(op->iterator, op->init, rhs, op->for_type(), op->device_api, op->body, op->vectorize_info());
+    *expr = new_for;
+
+    Visit(&new_for.As<ir::For>()->body);
+  }
+};
+
+}  // namespace
+
+void TransformPolyForToFor(Expr* expr, bool auto_separate) { PolyForWithSimpleConditionToForMutator()(expr); }
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/transform_polyfor_to_for.h b/paddle/cinn/optim/transform_polyfor_to_for.h
new file mode 100644
index 0000000000000..d31bc6c4584f7
--- /dev/null
+++ b/paddle/cinn/optim/transform_polyfor_to_for.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +//! Transform the PolyFor node to For node. This will also separate the PolyFor with Min or Max conditions into two For +//! nodes if \p auto_separate is true. +void TransformPolyForToFor(Expr* expr, bool auto_separate = true); + +namespace detail { + +void PolyForWithSimpleConditionToFor(Expr* expr); + +} // namespace detail + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/transform_polyfor_to_for_test.cc b/paddle/cinn/optim/transform_polyfor_to_for_test.cc new file mode 100644 index 0000000000000..d98dd770c4549 --- /dev/null +++ b/paddle/cinn/optim/transform_polyfor_to_for_test.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
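+// The test below lowers a 512x200x500 matmul, splits its i and j loops by a factor of 8, runs
+// TransformPolyForToFor, and compares the generated C code with a golden string; the non-divisible
+// j split (500 is not a multiple of 8) shows up as the cinn_min(8, (500 + (-8 * j_outer))) inner extent.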
+ +#include "cinn/optim/transform_polyfor_to_for.h" + +#include + +#include "cinn/cinn.h" + +namespace cinn { +namespace optim { + +TEST(Expr, basic) { + using namespace ir; // NOLINT + + Expr M(512); + Expr K(200); + Expr N(500); + Placeholder A("A", {M, K}); + Placeholder B("B", {K, N}); + + // C = A * B + Var k(K.as_int32(), "k0"); + + Tensor C = Compute( + {M, N}, [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); }, "C"); + + auto stages = CreateStages({C}); + + { + stages[C]->Split("i", 8); + stages[C]->Split("j", 8); + } + + // Code gen + auto func = Lower("matmul", stages, {A, B, C}); + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + + { + ir::Module::Builder builder("module1", target); + builder.AddFunction(func); + + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "out:\n" << out; + } + + optim::TransformPolyForToFor(&func->body); + + { + ir::Module::Builder builder("module1", target); + builder.AddFunction(func); + + CodeGenC codegen(target); + codegen.SetInlineBuiltinCodes(false); + auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "out:\n" << out; + + auto target_out = R"ROC( +#include +#include + +void matmul(void* _args, int32_t num_args) +{ + const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); + const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); + cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); + cinn_buffer_malloc((void*)(0), _C); + const float* A = ((const float*)(_A->memory)); + const float* B = ((const float*)(_B->memory)); + float* C = ((float*)(_C->memory)); + float* C__reduce_init = ((float*)(_C->memory)); + for (int32_t i_outer = 0; i_outer < 64; i_outer += 1) { + for (int32_t i_inner = 0; i_inner < 8; i_inner += 1) { + for (int32_t j_outer = 0; j_outer < 63; j_outer += 1) { + for (int32_t j_inner = 0; j_inner < cinn_min(8, (500 + (-8 * j_outer))); j_inner += 1) { + C__reduce_init[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] = 0.00000000f; + for (int32_t k0 = 0; k0 < 200; k0 += 1) { + C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] = fma(A[((200 * i_inner) + ((1600 * i_outer) + k0))], B[((8 * j_outer) + ((500 * k0) + j_inner))], C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))]); + }; + }; + }; + }; + }; + cinn_buffer_free((void*)(0), _C); +} +)ROC"; + EXPECT_EQ(utils::Trim(target_out), utils::Trim(out)); + } +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/unroll_loops.cc b/paddle/cinn/optim/unroll_loops.cc new file mode 100755 index 0000000000000..7262a77878900 --- /dev/null +++ b/paddle/cinn/optim/unroll_loops.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/unroll_loops.h"
+
+#include
+#include
+
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/ir_replace.h"
+
+namespace cinn {
+namespace optim {
+
+namespace {
+
+struct UnrollMutator : public ir::IRMutator<> {
+  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
+
+ private:
+  // update auto_max_step_ from the specific attribute of a ScheduleBlock
+  void Visit(const ir::ScheduleBlock* op, Expr* expr) override {
+    auto attr_it = op->attrs.find(ir::attr::auto_unroll_max_step);
+    if (attr_it != op->attrs.end()) {
+      const int* attr_v = absl::get_if<int>(&attr_it->second);
+      if (attr_v) {
+        int value = *attr_v;
+        std::swap(auto_max_step_, value);
+        VLOG(5) << "auto_max_step is updated:" << auto_max_step_;
+        ir::IRMutator<>::Visit(op, expr);
+        std::swap(auto_max_step_, value);
+        return;
+      } else {
+        LOG(WARNING) << "Get invalid value of attr:" << ir::attr::auto_unroll_max_step;
+      }
+    }
+    ir::IRMutator<>::Visit(op, expr);
+  }
+
+  // count a Store node as a plain statement
+  void Visit(const ir::Store* op, Expr* expr) override {
+    IRMutator<>::Visit(op, expr);
+    ++flat_step_;
+  }
+
+  // check whether a for-loop can be unrolled, and unroll it if so
+  void Visit(const ir::For* op, Expr* expr) override {
+    IRMutator<>::Visit(op, expr);
+    if (op->extent.As<ir::IntImm>() == nullptr) {
+      VLOG(5) << "A loop to be unrolled should have a constant extent";
+      return;
+    }
+    int extent = op->extent.as_int32();
+
+    // check whether this for-loop can be unrolled under the auto-unroll conditions
+    bool unrollable =
+        (op->is_serial() && extent >= 0 && not_unrolled_depth_ == 0 && extent * flat_step_ <= auto_max_step_);
+
+    // it can also be unrolled if it carries the unrolled tag
+    unrollable = (unrollable || op->is_unrolled()) && extent <= max_unroll_extent_;
+
+    if (unrollable) {
+      Unroll(op, expr);
+      flat_step_ *= extent;
+    } else {
+      ++not_unrolled_depth_;
+    }
+  }
+
+  //! Unroll a forloop.
+  void Unroll(const ir::For* op, Expr* expr) {
+    std::vector<Expr> body;
+
+    auto* min = op->min.As<ir::IntImm>();
+    auto* extent = op->extent.As<ir::IntImm>();
+    if (!(min && extent)) return;
+
+    for (int i = min->value; i < extent->value; i++) {
+      Expr start = op->min + i;
+      body.push_back(optim::IRCopy(op->body));
+      optim::IrReplace(&body.back(), op->loop_var, start);
+    }
+
+    *expr = ir::Block::Make(body);
+  }
+
+ private:
+  // max permitted steps to be automatically unrolled in total
+  int auto_max_step_ = 0;
+  // max permitted extent of a loop to be unrolled
+  int max_unroll_extent_ = 50;
+
+  // the number of unrolled steps or plain statements so far
+  int flat_step_ = 0;
+  // the number of nested loops that were not unrolled
+  int not_unrolled_depth_ = 0;
+};
+
+}  // namespace
+
+void UnrollLoop(Expr* expr) { UnrollMutator()(expr); }
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/unroll_loops.h b/paddle/cinn/optim/unroll_loops.h
new file mode 100644
index 0000000000000..283991b4f81dc
--- /dev/null
+++ b/paddle/cinn/optim/unroll_loops.h
@@ -0,0 +1,24 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "cinn/ir/ir.h"
+
+namespace cinn {
+namespace optim {
+
+//! Unroll for-loops that carry the `unrolled` tag, as well as small loops covered by the auto_unroll_max_step limit.
+void UnrollLoop(Expr* expr);
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/unroll_loops_test.cc b/paddle/cinn/optim/unroll_loops_test.cc
new file mode 100644
index 0000000000000..e4dbf49055da7
--- /dev/null
+++ b/paddle/cinn/optim/unroll_loops_test.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/unroll_loops.h"
+
+#include
+
+#include
+
+#include "cinn/cinn.h"
+#include "cinn/ir/ir_schedule.h"
+#include "cinn/lang/lower.h"
+
+namespace cinn {
+namespace optim {
+
+TEST(UnrollLoops, unrolled_tag) {
+  using namespace ir;
+
+  Expr M(100);
+  Expr N(4);
+
+  Placeholder<float> A("A", {M, N});
+  Placeholder<float> B("B", {M, N});
+
+  Tensor C = Compute(
+      {M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");
+
+  auto stages = CreateStages({C});
+
+  Target target = common::DefaultHostTarget();
+  auto func = cinn::lang::LowerVec("test_unrolled_tag", stages, {A, B, C}, {}, {}, nullptr, target, true);
+  auto ast_expr = func[0]->body;
+
+  ir::ModuleExpr mod_expr({ast_expr});
+  ir::IRSchedule ir_sch(mod_expr);
+  auto loops = ir_sch.GetLoops("C");
+  ASSERT_EQ(loops.size(), 2U);
+
+  // the extent of the loop exceeds the max value permitted by the unroll_loops pass
+  // (currently set to 50), so the loop cannot actually be unrolled
+  loops[1].As<ir::For>()->extent.As<ir::IntImm>()->value = 51;
+  ir_sch.Unroll(loops[1]);
+  UnrollLoop(&ast_expr);
+  loops = ir_sch.GetLoops("C");
+  ASSERT_EQ(loops.size(), 2U);
+
+  // unrolled correctly
+  loops[1].As<ir::For>()->extent.As<ir::IntImm>()->value = 4;
+  UnrollLoop(&ast_expr);
+  EXPECT_EQ(ir_sch.GetLoops("C").size(), 1);
+}
+
+TEST(UnrollLoops, auto_unroll) {
+  using namespace ir;
+
+  Expr M(100);
+  Expr N(4);
+  Expr O(5);
+  Expr const_value(float(2.11));
+
+  Placeholder<float> A("A", {M, N, O});
+
+  // B = A + 2.11
+  Tensor B = Compute(
+      {M, N, O}, [&](Var i, Var j, Var k) { return A(i, j, k) + const_value; }, "B");
+
+  auto stages = CreateStages({B});
+  Target target = common::DefaultHostTarget();
+  auto func = cinn::lang::LowerVec("test_auto_unroll", stages, {A, B}, {}, {}, nullptr, target, true);
+  auto ast_expr = func[0]->body;
+  ir::ModuleExpr mod_expr({ast_expr});
+  ir::IRSchedule ir_sch(mod_expr);
+  ASSERT_EQ(ir_sch.GetLoops("B").size(), 3);
+  UnrollLoop(&ast_expr);
+  // check that it remains unchanged after the UnrollLoop pass above
+  ASSERT_EQ(ir_sch.GetLoops("B").size(), 3);
+
+  ASSERT_TRUE(ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>() != nullptr);
+  auto* block_realize =
ast_expr.As<ir::Block>()->stmts.front().As<ir::ScheduleBlockRealize>();
+  auto* schedule_block = block_realize->schedule_block.As<ir::ScheduleBlock>();
+  // set the 'auto_unroll_max_step' attribute to 25, which is bigger than
+  // the product of the extents of the inner two loops
+  schedule_block->attrs.emplace(ir::attr::auto_unroll_max_step, 25);
+  UnrollLoop(&ast_expr);
+  EXPECT_EQ(ir_sch.GetLoops("B").size(), 1);
+}
+
+}  // namespace optim
+}  // namespace cinn
diff --git a/paddle/cinn/optim/var_mod_simplify.cc b/paddle/cinn/optim/var_mod_simplify.cc
new file mode 100644
index 0000000000000..af099fe028391
--- /dev/null
+++ b/paddle/cinn/optim/var_mod_simplify.cc
@@ -0,0 +1,91 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/var_mod_simplify.h"
+
+#include
+
+#include "cinn/common/cas.h"
+#include "cinn/ir/ir_mutator.h"
+#include "cinn/ir/ir_printer.h"
+
+namespace cinn::optim {
+
+namespace {
+using namespace ir;  // NOLINT
+
+struct ReplaceModWithDivMutator : public ir::IRMutator<> {
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  void Visit(const Mod* op, Expr* expr) override {
+    auto* node = expr->As<Mod>();
+    auto a = node->operand(0);
+    auto b = node->operand(1);
+    // a % b == a - (a / b) * b
+    *expr = ir::Div::Make(a, b);
+    *expr = ir::Mul::Make(b, *expr);
+    *expr = ir::Sub::Make(a, *expr);
+  }
+};
+
+struct ReplaceDivWithVarMutator : public ir::IRMutator<> {
+  absl::flat_hash_map<std::string, Expr> div_var_map_;
+  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
+
+  void Visit(const Div* op, Expr* expr) override {
+    auto* node = expr->As<Div>();
+
+    auto a = node->operand(0);
+    auto b = node->operand(1);
+    // only deal with var/int
+    if (a.is_var() && b.is_constant()) {
+      auto a_var = a.As<_Var_>();
+      auto b_int = b.As<IntImm>();
+      CHECK(a_var);
+      CHECK(b_int);
+      std::string var_name = a_var->name + "/" + std::to_string(b_int->value);
+      div_var_map_[var_name] = ir::Div::Make(a, b);
+      *expr = Var(var_name);
+    }
+  }
+};
+
+struct ReplaceVarWithDivMutator : public ir::IRMutator<> {
+  absl::flat_hash_map<std::string, Expr> div_var_map_;
+  void operator()(Expr* x, const absl::flat_hash_map<std::string, Expr>& div_var_map) {
+    div_var_map_ = div_var_map;
+    ir::IRMutator<>::Visit(x, x);
+  }
+
+  void Visit(const _Var_* op, Expr* expr) override {
+    auto* node = expr->As<_Var_>();
+    CHECK(node);
+    if (div_var_map_.count(node->name)) {
+      *expr = div_var_map_[node->name];
+    }
+  }
+};
+
+}  // namespace
+
+void VarModSimplify(Expr* e) {
+  *e = common::AutoSimplify(*e);
+  ReplaceModWithDivMutator()(e);
+  ReplaceDivWithVarMutator mutator;
+  mutator(e);
+  *e = common::AutoSimplify(*e);
+  ReplaceVarWithDivMutator()(e, mutator.div_var_map_);
+}
+
+}  // namespace cinn::optim
diff --git a/paddle/cinn/optim/var_mod_simplify.h b/paddle/cinn/optim/var_mod_simplify.h
new file mode 100644
index 0000000000000..fb01e7e39215a
--- /dev/null
+++ b/paddle/cinn/optim/var_mod_simplify.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "cinn/ir/ir.h"
+
+/** Simplify expressions that contain div and mod of vars.
+ *
+ * For example, input the code
+ * \code
+ * ((i_j_k_fused / 3) * 144) + (48 * (i_j_k_fused % 3))
+ * \endcode
+ *
+ * with `i_j_k_fused` treated as a var, it will be simplified to
+ * \code
+ * 48 * i_j_k_fused
+ * \endcode
+ */
+namespace cinn::optim {
+
+void VarModSimplify(Expr* e);
+
+}  // namespace cinn::optim
diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc
new file mode 100644
index 0000000000000..fc6d97f1daf4b
--- /dev/null
+++ b/paddle/cinn/optim/vectorize_loops.cc
@@ -0,0 +1,890 @@
+// Copyright (c) 2021 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cinn/optim/vectorize_loops.h"
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "cinn/common/cas.h"
+#include "cinn/common/ir_util.h"
+#include "cinn/ir/collect_ir_nodes.h"
+#include "cinn/ir/ir_operators.h"
+#include "cinn/ir/ir_printer.h"
+#include "cinn/optim/ir_copy.h"
+#include "cinn/optim/ir_replace.h"
+#include "cinn/optim/ir_simplify.h"
+#include "cinn/optim/tensor_write_tell.h"
+#include "cinn/optim/unroll_loops.h"
+#include "cinn/utils/functional.h"
+
+namespace cinn {
+namespace optim {
+using namespace ir;  // NOLINT
+using common::make_const;
+using common::make_one;
+using common::make_zero;
+
+//! Widen an expression to the given number of lanes.
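+//! For example (illustrative): widening a scalar constant 1.f to 4 lanes yields Broadcast(1.f, 4), and widening an
+//! existing Broadcast(v, 2) to 4 lanes re-broadcasts its value v to 4 lanes.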
+Expr Widen(Expr e, int lanes) {
+  if (e.type().lanes() == lanes) return e;
+  if (const ir::Broadcast *op = e.As<ir::Broadcast>()) {
+    if (lanes % op->lanes == 0) {
+      return ir::Broadcast::Make(op->value, lanes);
+    }
+  }
+
+  CHECK_EQ(e.type().lanes(), 1) << "Cannot broadcast lanes from " << e.type().lanes() << " to " << lanes;
+  return ir::Broadcast::Make(e, lanes);
+}
+
+// Tell whether a tensor can be vectorized or not on CUDA by collecting the names
+// of tensors which meet all check predicates of vectorizing.
+class TensorVectorizeTeller : public ir::IRMutator<const Expr *> {
+ public:
+  TensorVectorizeTeller(const Var &iter_var,
+                        const int factor,
+                        const absl::flat_hash_map<std::string, common::CasInterval> *var_intervals)
+      : iter_var_(iter_var), factor_(factor), var_intervals_(var_intervals) {}
+
+  void Collect(const Expr *op) { IRMutator::Visit(op, op); }
+
+  // return true if the input tensor can be vectorized
+  bool CanBeVectorized(const std::string &tensor_name) const {
+    auto it = tensor2flag_.find(tensor_name);
+    return it != tensor2flag_.end() && it->second;
+  }
+
+ private:
+  const Var iter_var_;  // loop var of the new for-loop split from the vectorized loop
+  const int factor_;
+  const absl::flat_hash_map<std::string, common::CasInterval> *var_intervals_;
+  // map of (tensor name) -> (bool flag) to identify whether a tensor can be vectorized or not
+  std::unordered_map<std::string, bool> tensor2flag_;
+
+  void Visit(const ir::Store *expr, const Expr *op) override {
+    auto *node = op->As<ir::Store>();
+    CHECK(node);
+    IRMutator::Visit(&node->value, &node->value);
+    auto *tensor = node->tensor.As<ir::_Tensor_>();
+    CHECK(tensor);
+
+    // a tensor should pass all pre-condition checks every time it appears
+    if (!tensor2flag_.count(tensor->name) || tensor2flag_.at(tensor->name)) {
+      bool flag = MeetConditions(node->tensor, node->indices);
+      tensor2flag_[tensor->name] = flag;
+    }
+  }
+
+  void Visit(const ir::Load *expr, const Expr *op) override {
+    auto *node = op->As<ir::Load>();
+    CHECK(node);
+    auto *tensor = node->tensor.As<ir::_Tensor_>();
+    CHECK(tensor);
+
+    // a tensor should pass all pre-condition checks every time it appears
+    if (!tensor2flag_.count(tensor->name) || tensor2flag_.at(tensor->name)) {
+      bool flag = MeetConditions(node->tensor, node->indices);
+      tensor2flag_[tensor->name] = flag;
+    }
+  }
+
+  // return true if the tensor meets all conditions for vectorizing
+  bool MeetConditions(const Expr &expr, const std::vector<ir::Expr> &indices) {
+    const ir::_Tensor_ *tensor = expr.As<ir::_Tensor_>();
+    auto find_matched_var_fn = [&](const Expr *x) { return x->As<_Var_>() && x->As<_Var_>()->name == iter_var_->name; };
+
+    // the size of the last dim should be divisible by the factor
+    if (tensor->shape.empty() || !tensor->shape.back().As<ir::IntImm>() ||
+        tensor->shape.back().as_int32() % factor_ != 0) {
+      VLOG(5) << "Size of the last dim of tensor:" << tensor->name << " can't be divisible by factor:" << factor_
+              << ", shape:" << utils::Join(tensor->shape, ",");
+      return false;
+    }
+
+    // the iter var must appear in the last index
+    if (indices.empty() || ir::CollectIRNodes(indices.back(), find_matched_var_fn).empty()) {
+      VLOG(5) << "Loop var:" << iter_var_->name << " is not used in the last index";
+      return false;
+    }
+
+    // the iter var can't appear in multiple indices
+    for (int i = 0; i < indices.size() - 1; ++i) {
+      auto repeat_found = ir::CollectIRNodes(indices[i], find_matched_var_fn);
+      if (!repeat_found.empty()) {
+        VLOG(5) << "Loop var:" << iter_var_->name << " is used at more than the last index, current:" << i;
+        return false;
+      }
+    }
+
+    // check that the tensor is accessed sequentially by comparing the indices one by one
+    Expr first_idx
= optim::IRCopy(indices.back());
+    optim::IrReplace(&first_idx, Expr(iter_var_), Expr(0));
+    const auto &interval = var_intervals_->at(iter_var_->name);
+    for (int i = 1; i < interval.r; ++i) {
+      Expr next_idx = optim::IRCopy(indices.back());
+      optim::IrReplace(&next_idx, Expr(iter_var_), Expr(i));
+      auto gap = common::AutoSimplify(Expr(next_idx - first_idx));
+      if (!gap.As<ir::IntImm>() || gap.as_int32() != i) {
+        VLOG(5) << "Tensor:" << tensor->name << " is not accessed sequentially, next:" << next_idx
+                << ", first:" << first_idx << ", gap:" << gap;
+        return false;
+      }
+      VLOG(5) << "Tensor:" << tensor->name << " is accessed sequentially, next:" << next_idx << ", first:" << first_idx
+              << ", gap:" << gap;
+    }
+
+    auto dtype = expr->type().ElementOf();
+    bool type_supported =
+        dtype.is_float(32) || dtype.is_int(32) || dtype.is_uint(32) || dtype.is_float16() || dtype.is_bfloat16();
+    if (!type_supported) {
+      VLOG(5) << "Only support vectorizing int, uint, float, float16 and bfloat16, but got " << dtype;
+      return false;
+    }
+    return true;
+  }
+};
+
+// Find tensors accessed sequentially in a for-loop to be vectorized,
+// and substitute the corresponding CUDA built-in vector types for them.
+class CudaVectorizer : public IRMutator<> {
+  const Var iter_var_;  // the loop var of the vectorized loop
+  const int factor_;    // the factor for vectorizing
+
+  TensorWriteTeller write_teller_;
+  TensorVectorizeTeller vectorized_teller_;
+
+  absl::flat_hash_map<std::string, Var> tensor2vectorized_vars_;
+  std::vector<Expr> vectorized_cast_exprs_;
+  std::vector<Expr> vectorized_store_exprs_;
+
+ public:
+  static constexpr int CudaVectorTypeMaxLanes = 8;
+  CudaVectorizer(const Var &iter_var,
+                 const int factor,
+                 const absl::flat_hash_map<std::string, common::CasInterval> *var_intervals)
+      : iter_var_(iter_var), factor_(factor), vectorized_teller_(iter_var, factor, var_intervals) {
+    CHECK(factor <= CudaVectorTypeMaxLanes)
+        << "The maximum lanes of valid CUDA vector types: " << CudaVectorTypeMaxLanes << ", but factor: " << factor;
+  }
+
+  // return all cast statements collected through vectorizing
+  std::vector<Expr> VectorizedTypeCastExprs() { return vectorized_cast_exprs_; }
+
+  // return all store statements collected through vectorizing
+  std::vector<Expr> VectorizedTypeStoreExprs() { return vectorized_store_exprs_; }
+
+  void Visit(Expr *expr) {
+    write_teller_.Collect(expr);
+    vectorized_teller_.Collect(expr);
+    IRMutator::Visit(expr, expr);
+  }
+
+  void Visit(const Load *op, Expr *expr) override {
+    auto *node = expr->As<Load>();
+    auto *tensor = node->tensor.As<ir::_Tensor_>();
+    if (node->is_addr_tensor() && vectorized_teller_.CanBeVectorized(tensor->name)) {
+      TensorVectorized(node, &node->indices, false);
+    }
+  }
+
+  void Visit(const Store *op, Expr *expr) override {
+    auto *node = expr->As<Store>();
+    auto *tensor = node->tensor.As<ir::_Tensor_>();
+    CHECK(tensor);
+    if (vectorized_teller_.CanBeVectorized(tensor->name)) {
+      TensorVectorized(node, &node->indices, true);
+    }
+
+    IRMutator::Visit(&node->value, &node->value);
+  }
+
+ private:
+  void TensorVectorized(ir::LoadStoreAddrMnger *node, std::vector<Expr> *indices, bool is_store) {
+    auto *tensor = node->tensor.As<ir::_Tensor_>();
+    VLOG(5) << "Vectorizing tensor:" << tensor->name;
+
+    // save the tensor and its corresponding vector name when it first appears
+    if (!tensor2vectorized_vars_.count(tensor->name)) {
+      AppendCast(node->tensor, *indices, is_store);
+    }
+
+    auto vectorized_var = tensor2vectorized_vars_.at(tensor->name);
+    // substitute a new tensor with the vector name and dtype
+    auto t = vectorized_var->type().is_cpp_handle() ?
node->tensor->type().PointerOf() : node->tensor->type();
+    node->tensor = ir::Tensor(vectorized_var->name, t, {Expr(factor_)}, {Expr(factor_)}, tensor->operation);
+    // keep only the last iteration index
+    indices->assign({iter_var_});
+  }
+
+  std::string GetVectorTypeName(Type type) {
+    std::string name_prefix = common::customized_type::kcuda_builtin_vector_t;
+#define GET_CUDA_VECTOR_TYPE_NAME(pred_expr, scalar_name)       \
+  if (pred_expr) {                                              \
+    return name_prefix + scalar_name + std::to_string(factor_); \
+  }
+
+    GET_CUDA_VECTOR_TYPE_NAME(type.is_int(32), "int");
+    GET_CUDA_VECTOR_TYPE_NAME(type.is_uint(32), "uint");
+    GET_CUDA_VECTOR_TYPE_NAME(type.is_float(32), "float");
+    GET_CUDA_VECTOR_TYPE_NAME(type.is_float16(), "half");
+    GET_CUDA_VECTOR_TYPE_NAME(type.is_bfloat16(), "bfloat16");
+#undef GET_CUDA_VECTOR_TYPE_NAME
+
+    // others are not implemented yet
+    CINN_NOT_IMPLEMENTED
+    return "";
+  }
+
+  void AppendCast(Expr tensor, const std::vector<Expr> &indices, bool is_store) {
+    auto *node = tensor.As<ir::_Tensor_>();
+    bool is_const = !write_teller_.IsWrite(node->name);
+
+    // generate the corresponding vector type
+    Type scalar_type = tensor->type().ElementOf();
+    Type vector_type_ptr(Type::type_t::Customized, scalar_type.bits(), factor_);
+    Type vector_type(Type::type_t::Customized, scalar_type.bits(), factor_);
+    vector_type_ptr.set_customized_type(GetVectorTypeName(scalar_type));
+    vector_type_ptr.set_cpp_handle();
+    vector_type_ptr.set_cpp_const(is_const);
+
+    vector_type.set_customized_type(GetVectorTypeName(scalar_type));
+    vector_type.set_cpp_const(is_const);
+
+    // generate a local vector variable to be used in subsequent statements
+    std::string vectorized_name = "vectorized_" + node->name;
+    Var vectorized_var = _Var_::Make(vectorized_name, vector_type);
+    tensor2vectorized_vars_.emplace(node->name, vectorized_var);
+
+    // generate a get_addr expr to get the address of the tensor
+    Expr converted_tensor = Load::Make(tensor, indices);
+    optim::IrReplace(&converted_tensor, iter_var_, Expr(int32_t(0)));
+    auto get_addr = ir::intrinsics::GetAddr::Make(converted_tensor);
+
+    // generate a let expression to cast the tensor into the local vector
+    auto cast = ir::Cast::Make(vector_type_ptr, get_addr);
+    if (!is_store) {
+      auto load = Load::Make(cast, {make_const(0)});
+      auto let = Let::Make(vectorized_var, load);
+      vectorized_cast_exprs_.emplace_back(let);
+      VLOG(5) << "Append a vectorized expr:" << let;
+    } else {
+      Var vectorized_ptr = _Var_::Make(vectorized_name + "_ptr", vector_type_ptr);
+
+      auto let1 = Let::Make(vectorized_ptr, cast);
+      auto let2 = Let::Make(vectorized_var, Expr(0));
+      vectorized_cast_exprs_.emplace_back(let1);
+      vectorized_cast_exprs_.emplace_back(let2);
+
+      VLOG(5) << "Append a vectorized expr:" << let1;
+      VLOG(5) << "Append a vectorized expr:" << let2;
+
+      auto t = ir::Tensor(
+          vectorized_ptr->name, node->type().PointerOf(), {Expr(factor_)}, {Expr(factor_)}, node->operation);
+      auto store = Store::Make(t, vectorized_var, {make_const(0)});
+
+      vectorized_store_exprs_.emplace_back(store);
+      VLOG(5) << "Append a vectorized expr:" << store;
+    }
+  }
+};
+
+//! Substitutes a vector for a scalar var in a Stmt.
+class Vectorizer : public IRMutator<> {
+  //! The name of the variable to be vectorized.
+  Var var;
+
+  int lanes_{-1};
+
+  bool need_scalarize_{false};
+
+  bool to_vectorize_{false};
+
+  Expr ramp_;
+
+  absl::flat_hash_map<std::string, common::CasInterval> var_intervals_;
+
+  //! A suffix to attach to widened variables.
+ std::string widen_suffix; + + public: + Vectorizer(const Var &var, int lanes, const absl::flat_hash_map &var_intervals = {}) + : var(var), lanes_(lanes), var_intervals_(var_intervals) { + // the identity ramp. + ramp_ = Ramp::Make(make_zero(), make_one(), lanes_); + } + + void Visit(Expr *expr) { + CHECK(!need_scalarize_); + IRMutator::Visit(expr, expr); + + if (need_scalarize_) { + need_scalarize_ = false; + Scalarize(expr); + } + } + + void Visit(const Cast *op, Expr *expr) override { + auto *node = expr->As(); + auto v0 = node->v(); + Visit(&node->v()); + if (v0.same_as(node->v())) return; + + Type t = op->type().with_lanes(node->v().type().lanes()); + node->set_type(t); + } + + void Visit(const _Var_ *op, Expr *expr) override { + if (op->name == var->name) { + *expr = Expr(ramp_); + } + } + + void Visit(const Add *op, Expr *expr) override { MutateAddSubOperator(op, expr); } + void Visit(const Sub *op, Expr *expr) override { MutateAddSubOperator(op, expr); } + void Visit(const Mul *op, Expr *expr) override { MutateMulDivOperator(op, expr); } + void Visit(const Div *op, Expr *expr) override { MutateMulDivOperator(op, expr); } + void Visit(const Mod *op, Expr *expr) override { MutateMulDivOperator(op, expr); } + void Visit(const Min *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const Max *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const EQ *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const NE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const LT *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const LE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const GT *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const GE *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const And *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + void Visit(const Or *op, Expr *expr) override { BinaryOperatorVec(op, expr); } + + void Visit(const Ramp *op, Expr *expr) override {} + + void Visit(const Select *op, Expr *expr) override { + auto *node = expr->As